refactor(llm): route calls through resolved models

This commit is contained in:
Anish Sarkar 2026-06-10 21:48:23 +05:30
parent 8b59ca59c1
commit 62ff97c830
5 changed files with 209 additions and 423 deletions

View file

@ -450,10 +450,10 @@ async def _resolve_agent_billing_for_search_space(
thread_id: int | None = None,
) -> tuple[UUID, str, str]:
"""Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space
agent LLM.
chat model.
Used by Celery tasks (podcast generation, video presentation) to bill the
search-space owner's premium credit pool when the agent LLM is premium.
search-space owner's premium credit pool when the chat model is premium.
Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``:

View file

@ -20,7 +20,11 @@ from typing import Any
from litellm import Router
from litellm.utils import ImageResponse
from app.services.provider_api_base import resolve_api_base
from app.services.model_resolver import (
NATIVE_PROVIDER_PREFIX,
native_connection_from_config,
to_litellm,
)
logger = logging.getLogger(__name__)
@ -30,17 +34,7 @@ IMAGE_GEN_AUTO_MODE_ID = 0
# Provider mapping for LiteLLM model string construction.
# Only includes providers that support image generation.
# See: https://docs.litellm.ai/docs/image_generation#supported-providers
IMAGE_GEN_PROVIDER_MAP = {
"OPENAI": "openai",
"AZURE_OPENAI": "azure",
"GOOGLE": "gemini", # Google AI Studio
"VERTEX_AI": "vertex_ai",
"BEDROCK": "bedrock", # AWS Bedrock
"RECRAFT": "recraft",
"OPENROUTER": "openrouter",
"XINFERENCE": "xinference",
"NSCALE": "nscale",
}
IMAGE_GEN_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX
class ImageGenRouterService:
@ -153,38 +147,11 @@ class ImageGenRouterService:
if not config.get("model_name") or not config.get("api_key"):
return None
# Build model string
provider = config.get("provider", "").upper()
if config.get("custom_provider"):
provider_prefix = config["custom_provider"]
else:
provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{config['model_name']}"
# Build litellm params
litellm_params: dict[str, Any] = {
"model": model_string,
"api_key": config.get("api_key"),
}
# Resolve ``api_base`` so deployments don't silently inherit
# ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against
# the wrong provider (see ``provider_api_base`` docstring).
api_base = resolve_api_base(
provider=provider,
provider_prefix=provider_prefix,
config_api_base=config.get("api_base"),
model_string, resolved_kwargs = to_litellm(
native_connection_from_config(config),
config["model_name"],
)
if api_base:
litellm_params["api_base"] = api_base
# Add api_version (required for Azure)
if config.get("api_version"):
litellm_params["api_version"] = config["api_version"]
# Add any additional litellm parameters
if config.get("litellm_params"):
litellm_params.update(config["litellm_params"])
litellm_params: dict[str, Any] = {"model": model_string, **resolved_kwargs}
# All configs use same alias "auto" for unified routing
deployment: dict[str, Any] = {

View file

@ -30,6 +30,11 @@ from litellm.exceptions import (
)
from pydantic import Field
from app.services.model_resolver import (
NATIVE_PROVIDER_PREFIX,
native_connection_from_config,
to_litellm,
)
from app.utils.perf import get_perf_logger
litellm.json_logs = False
@ -96,52 +101,8 @@ def _sanitize_content(content: Any) -> Any:
# Special ID for Auto mode - uses router for load balancing
AUTO_MODE_ID = 0
# Provider mapping for LiteLLM model string construction
PROVIDER_MAP = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama_chat",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"OPENROUTER": "openrouter",
"COMETAPI": "cometapi",
"XAI": "xai",
"BEDROCK": "bedrock",
"AWS_BEDROCK": "bedrock", # Legacy support
"VERTEX_AI": "vertex_ai",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
"REPLICATE": "replicate",
"PERPLEXITY": "perplexity",
"ANYSCALE": "anyscale",
"DEEPINFRA": "deepinfra",
"CEREBRAS": "cerebras",
"SAMBANOVA": "sambanova",
"AI21": "ai21",
"CLOUDFLARE": "cloudflare",
"DATABRICKS": "databricks",
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
"GITHUB_MODELS": "github",
"HUGGINGFACE": "huggingface",
"MINIMAX": "openai",
"CUSTOM": "custom",
}
# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were
# hoisted to ``app.services.provider_api_base`` so vision and image-gen
# call sites can share the exact same defense (OpenRouter / Groq / etc.
# 404-ing against an inherited Azure endpoint). Re-exported here for
# backward compatibility with any external import.
from app.services.provider_api_base import ( # noqa: E402
resolve_api_base,
)
# Historical export kept for callers that still import PROVIDER_MAP.
PROVIDER_MAP = NATIVE_PROVIDER_PREFIX
class LLMRouterService:
@ -420,38 +381,11 @@ class LLMRouterService:
if not config.get("model_name") or not config.get("api_key"):
return None
# Build model string
provider = config.get("provider", "").upper()
if config.get("custom_provider"):
provider_prefix = config["custom_provider"]
model_string = f"{provider_prefix}/{config['model_name']}"
else:
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{config['model_name']}"
# Build litellm params
litellm_params = {
"model": model_string,
"api_key": config.get("api_key"),
}
# Resolve ``api_base``. Config value wins; otherwise apply a
# provider-aware default so the deployment does not silently
# inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
# requests to the wrong endpoint. See ``provider_api_base``
# docstring for the motivating bug (OpenRouter models 404-ing
# against an Azure endpoint).
api_base = resolve_api_base(
provider=provider,
provider_prefix=provider_prefix,
config_api_base=config.get("api_base"),
model_string, resolved_kwargs = to_litellm(
native_connection_from_config(config),
config["model_name"],
)
if api_base:
litellm_params["api_base"] = api_base
# Add any additional litellm parameters
if config.get("litellm_params"):
litellm_params.update(config["litellm_params"])
litellm_params = {"model": model_string, **resolved_kwargs}
# Extract rate limits if provided
deployment = {

View file

@ -6,9 +6,10 @@ from langchain_core.messages import HumanMessage
from langchain_litellm import ChatLiteLLM
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.config import config
from app.db import NewLLMConfig, SearchSpace
from app.db import Model, SearchSpace
from app.services.llm_router_service import (
AUTO_MODE_ID,
ChatLiteLLMRouter,
@ -16,7 +17,7 @@ from app.services.llm_router_service import (
get_auto_mode_llm,
is_auto_mode,
)
from app.services.provider_api_base import resolve_api_base
from app.services.model_resolver import native_connection_from_config, to_litellm
from app.services.token_tracking_service import token_tracker
# Configure litellm to automatically drop unsupported parameters
@ -66,6 +67,29 @@ def _is_interactive_auth_provider(
return False
def _legacy_config_connection(
    *,
    provider: str,
    model_name: str,
    api_key: str | None,
    api_base: str | None,
    custom_provider: str | None = None,
    litellm_params: dict | None = None,
    api_version: str | None = None,
) -> tuple[str, dict]:
    """Resolve a flat, legacy-style LLM config into LiteLLM call arguments.

    The keyword arguments are packed into a config mapping, converted to a
    connection by ``native_connection_from_config`` and then handed to
    ``to_litellm`` together with ``model_name``.

    Args:
        provider: Provider identifier stored on the legacy config.
        model_name: Model name; also used as the model id passed to
            ``to_litellm``.
        api_key: Credential for the provider, if any.
        api_base: Explicit endpoint override, if any.
        custom_provider: Optional custom provider value from the config.
        litellm_params: Extra LiteLLM parameters; a falsy value is replaced
            by an empty dict before resolution.
        api_version: Optional API version from the config.

    Returns:
        The result of ``to_litellm``, annotated as ``(model_string, kwargs)``.
    """
    # Key order is kept identical to the legacy config layout.
    legacy_cfg = dict(
        provider=provider,
        model_name=model_name,
        api_key=api_key,
        api_base=api_base,
        custom_provider=custom_provider,
        api_version=api_version,
        litellm_params=litellm_params if litellm_params else {},
    )
    return to_litellm(native_connection_from_config(legacy_cfg), model_name)
class LLMRole:
AGENT = "agent" # For agent/chat operations
@ -102,6 +126,60 @@ def get_global_llm_config(llm_config_id: int) -> dict | None:
return None
def get_global_model(model_id: int) -> dict | None:
    """Return the first ``config.GLOBAL_MODELS`` entry whose ``id`` is *model_id*.

    Returns ``None`` when no entry matches.
    """
    for candidate in config.GLOBAL_MODELS:
        if candidate.get("id") == model_id:
            return candidate
    return None
def get_global_connection(connection_id: int) -> dict | None:
    """Return the first ``config.GLOBAL_CONNECTIONS`` entry with a matching ``id``.

    Returns ``None`` when no entry has ``id == connection_id``.
    """
    for candidate in config.GLOBAL_CONNECTIONS:
        if candidate.get("id") == connection_id:
            return candidate
    return None
def _has_capability(model: dict | Model, capability: str) -> bool:
    """Report whether *model* declares *capability* with a truthy value.

    Args:
        model: Either a model mapping (read via its ``"capabilities"`` key) or
            an object exposing a ``capabilities`` attribute.
        capability: Capability name to look up, e.g. ``"chat"`` or ``"vision"``.

    Returns:
        ``True`` only when the capabilities mapping holds a truthy value for
        *capability*; ``False`` otherwise.
    """
    if isinstance(model, dict):
        caps = model.get("capabilities")
    else:
        caps = model.capabilities
    # A mapping entry may carry an explicit null for "capabilities"; treat
    # that the same as a missing key instead of raising on ``None.get``,
    # mirroring the guard already applied to the attribute form.
    return bool((caps or {}).get(capability))
def _chat_litellm_from_resolved(
    *,
    conn: dict | object,
    model_id: str,
    disable_streaming: bool = False,
) -> tuple[str, dict]:
    """Build chat-model constructor kwargs from a resolved connection.

    Args:
        conn: Connection passed through unchanged to ``to_litellm``.
        model_id: Model identifier passed through to ``to_litellm``.
        disable_streaming: When true, ``"disable_streaming": True`` is set on
            the returned kwargs.

    Returns:
        ``(model_string, kwargs)`` where ``kwargs`` contains ``"model"`` plus
        every keyword produced by ``to_litellm``.
    """
    model_string, extra_kwargs = to_litellm(conn, model_id)
    streaming_override = {"disable_streaming": True} if disable_streaming else {}
    merged_kwargs = {"model": model_string, **extra_kwargs, **streaming_override}
    return model_string, merged_kwargs
async def _get_db_model(
    session: AsyncSession,
    model_id: int,
    search_space: SearchSpace,
) -> Model | None:
    """Load an enabled ``Model`` row that *search_space* is allowed to use.

    The model's connection is loaded alongside it via ``selectinload``.

    Returns:
        The model, or ``None`` when no enabled model has this id, its
        connection is missing or disabled, or the connection is bound to a
        different search space or a different user than the search-space owner.
    """
    stmt = (
        select(Model)
        .options(selectinload(Model.connection))
        .where(Model.id == model_id, Model.enabled.is_(True))
    )
    db_model = (await session.execute(stmt)).scalars().first()
    if not db_model:
        return None

    connection = db_model.connection
    if not connection or not connection.enabled:
        return None

    # A connection with no search-space / user binding passes these guards;
    # only an explicit binding to someone else rejects the model.
    if connection.search_space_id:
        if connection.search_space_id != search_space.id:
            return None
    if connection.user_id:
        if connection.user_id != search_space.user_id:
            return None
    return db_model
async def validate_llm_config(
provider: str,
model_name: str,
@ -146,62 +224,15 @@ async def validate_llm_config(
return False, msg
try:
# Build the model string for litellm
if custom_provider:
model_string = f"{custom_provider}/{model_name}"
else:
# Map provider enum to litellm format
provider_map = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama_chat",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"OPENROUTER": "openrouter",
"COMETAPI": "cometapi",
"XAI": "xai",
"BEDROCK": "bedrock",
"AWS_BEDROCK": "bedrock", # Legacy support (backward compatibility)
"VERTEX_AI": "vertex_ai",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
"REPLICATE": "replicate",
"PERPLEXITY": "perplexity",
"ANYSCALE": "anyscale",
"DEEPINFRA": "deepinfra",
"CEREBRAS": "cerebras",
"SAMBANOVA": "sambanova",
"AI21": "ai21",
"CLOUDFLARE": "cloudflare",
"DATABRICKS": "databricks",
# Chinese LLM providers
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai", # GLM needs special handling
"MINIMAX": "openai",
"GITHUB_MODELS": "github",
}
provider_prefix = provider_map.get(provider, provider.lower())
model_string = f"{provider_prefix}/{model_name}"
# Create ChatLiteLLM instance
litellm_kwargs = {
"model": model_string,
"api_key": api_key,
"timeout": 30, # Set a timeout for validation
}
# Add optional parameters
if api_base:
litellm_kwargs["api_base"] = api_base
# Add any additional litellm parameters
if litellm_params:
litellm_kwargs.update(litellm_params)
model_string, resolved_kwargs = _legacy_config_connection(
provider=provider,
model_name=model_name,
api_key=api_key,
api_base=api_base,
custom_provider=custom_provider,
litellm_params=litellm_params,
)
litellm_kwargs = {"model": model_string, **resolved_kwargs, "timeout": 30}
from app.agents.chat.runtime.llm_config import (
SanitizedChatLiteLLM,
@ -283,9 +314,9 @@ async def get_search_space_llm_instance(
logger.error(f"Search space {search_space_id} not found")
return None
# Get the appropriate LLM config ID based on role
# Get the appropriate model binding ID based on role
if role == LLMRole.AGENT:
llm_config_id = search_space.agent_llm_id
llm_config_id = search_space.chat_model_id
else:
logger.error(f"Invalid LLM role: {role}")
return None
@ -312,70 +343,26 @@ async def get_search_space_llm_instance(
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
return None
# Check if this is a global config (negative ID)
# Check if this is a global virtual model (negative ID)
if llm_config_id < 0:
global_config = get_global_llm_config(llm_config_id)
if not global_config:
logger.error(f"Global LLM config {llm_config_id} not found")
global_model = get_global_model(llm_config_id)
if not global_model or not _has_capability(global_model, "chat"):
logger.error(f"Global chat model {llm_config_id} not found")
return None
global_connection = get_global_connection(global_model["connection_id"])
if not global_connection:
logger.error(
"Global connection %s not found for model %s",
global_model["connection_id"],
llm_config_id,
)
return None
# Build model string for global config
if global_config.get("custom_provider"):
model_string = (
f"{global_config['custom_provider']}/{global_config['model_name']}"
)
else:
provider_map = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama_chat",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"OPENROUTER": "openrouter",
"COMETAPI": "cometapi",
"XAI": "xai",
"BEDROCK": "bedrock",
"AWS_BEDROCK": "bedrock",
"VERTEX_AI": "vertex_ai",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
"REPLICATE": "replicate",
"PERPLEXITY": "perplexity",
"ANYSCALE": "anyscale",
"DEEPINFRA": "deepinfra",
"CEREBRAS": "cerebras",
"SAMBANOVA": "sambanova",
"AI21": "ai21",
"CLOUDFLARE": "cloudflare",
"DATABRICKS": "databricks",
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
"MINIMAX": "openai",
}
provider_prefix = provider_map.get(
global_config["provider"], global_config["provider"].lower()
)
model_string = f"{provider_prefix}/{global_config['model_name']}"
# Create ChatLiteLLM instance from global config
litellm_kwargs = {
"model": model_string,
"api_key": global_config["api_key"],
}
if global_config.get("api_base"):
litellm_kwargs["api_base"] = global_config["api_base"]
if global_config.get("litellm_params"):
litellm_kwargs.update(global_config["litellm_params"])
if disable_streaming:
litellm_kwargs["disable_streaming"] = True
_, litellm_kwargs = _chat_litellm_from_resolved(
conn=global_connection,
model_id=global_model["model_id"],
disable_streaming=disable_streaming,
)
from app.agents.chat.runtime.llm_config import (
SanitizedChatLiteLLM,
@ -383,80 +370,18 @@ async def get_search_space_llm_instance(
return SanitizedChatLiteLLM(**litellm_kwargs)
# Get the LLM configuration from database (NewLLMConfig)
result = await session.execute(
select(NewLLMConfig).where(
NewLLMConfig.id == llm_config_id,
NewLLMConfig.search_space_id == search_space_id,
)
)
llm_config = result.scalars().first()
if not llm_config:
model = await _get_db_model(session, llm_config_id, search_space)
if not model or not _has_capability(model, "chat"):
logger.error(
f"LLM config {llm_config_id} not found in search space {search_space_id}"
f"Chat model {llm_config_id} not found in search space {search_space_id}"
)
return None
# Build the model string for litellm
if llm_config.custom_provider:
model_string = f"{llm_config.custom_provider}/{llm_config.model_name}"
else:
# Map provider enum to litellm format
provider_map = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama_chat",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"OPENROUTER": "openrouter",
"COMETAPI": "cometapi",
"XAI": "xai",
"BEDROCK": "bedrock",
"AWS_BEDROCK": "bedrock",
"VERTEX_AI": "vertex_ai",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
"REPLICATE": "replicate",
"PERPLEXITY": "perplexity",
"ANYSCALE": "anyscale",
"DEEPINFRA": "deepinfra",
"CEREBRAS": "cerebras",
"SAMBANOVA": "sambanova",
"AI21": "ai21",
"CLOUDFLARE": "cloudflare",
"DATABRICKS": "databricks",
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
"MINIMAX": "openai",
"GITHUB_MODELS": "github",
}
provider_prefix = provider_map.get(
llm_config.provider.value, llm_config.provider.value.lower()
)
model_string = f"{provider_prefix}/{llm_config.model_name}"
# Create ChatLiteLLM instance
litellm_kwargs = {
"model": model_string,
"api_key": llm_config.api_key,
}
# Add optional parameters
if llm_config.api_base:
litellm_kwargs["api_base"] = llm_config.api_base
# Add any additional litellm parameters
if llm_config.litellm_params:
litellm_kwargs.update(llm_config.litellm_params)
if disable_streaming:
litellm_kwargs["disable_streaming"] = True
_, litellm_kwargs = _chat_litellm_from_resolved(
conn=model.connection,
model_id=model.model_id,
disable_streaming=disable_streaming,
)
from app.agents.chat.runtime.llm_config import (
SanitizedChatLiteLLM,
@ -474,7 +399,7 @@ async def get_search_space_llm_instance(
async def get_agent_llm(
session: AsyncSession, search_space_id: int, disable_streaming: bool = False
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
"""Get the search space's agent LLM instance for chat operations."""
"""Get the search space's chat model instance."""
return await get_search_space_llm_instance(
session,
search_space_id,
@ -488,22 +413,19 @@ async def get_vision_llm(
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
"""Get the search space's vision LLM instance for screenshot analysis.
Resolves from the dedicated VisionLLMConfig system:
Resolves from the new connection/model role bindings:
- Auto mode (ID 0): VisionLLMRouterService
- Global (negative ID): YAML configs
- DB (positive ID): VisionLLMConfig table
- Global (negative ID): virtual GLOBAL models from YAML
- DB (positive ID): Model + Connection tables
Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM`
so each ``ainvoke`` debits the search-space owner's premium credit
pool. User-owned BYOK configs and free global configs are returned
unwrapped — they don't consume premium credit (issue M).
"""
from app.db import VisionLLMConfig
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
from app.services.vision_llm_router_service import (
VISION_PROVIDER_MAP,
VisionLLMRouterService,
get_global_vision_llm_config,
is_vision_auto_mode,
)
@ -516,13 +438,43 @@ async def get_vision_llm(
logger.error(f"Search space {search_space_id} not found")
return None
config_id = search_space.vision_llm_config_id
owner_user_id = search_space.user_id
# Prefer the selected chat model when it is vision-capable.
chat_model_id = search_space.chat_model_id
if chat_model_id and chat_model_id != AUTO_MODE_ID:
if chat_model_id < 0:
chat_model = get_global_model(chat_model_id)
if chat_model and _has_capability(chat_model, "vision"):
global_connection = get_global_connection(chat_model["connection_id"])
if global_connection:
model_string, litellm_kwargs = _chat_litellm_from_resolved(
conn=global_connection,
model_id=chat_model["model_id"],
)
from app.agents.chat.runtime.llm_config import (
SanitizedChatLiteLLM,
)
return SanitizedChatLiteLLM(**litellm_kwargs)
else:
chat_model = await _get_db_model(session, chat_model_id, search_space)
if chat_model and _has_capability(chat_model, "vision"):
_, litellm_kwargs = _chat_litellm_from_resolved(
conn=chat_model.connection,
model_id=chat_model.model_id,
)
from app.agents.chat.runtime.llm_config import (
SanitizedChatLiteLLM,
)
return SanitizedChatLiteLLM(**litellm_kwargs)
config_id = search_space.vision_model_id
if config_id is None:
logger.error(f"No vision LLM configured for search space {search_space_id}")
return None
owner_user_id = search_space.user_id
if is_vision_auto_mode(config_id):
if not VisionLLMRouterService.is_initialized():
logger.error(
@ -546,34 +498,24 @@ async def get_vision_llm(
return None
if config_id < 0:
global_cfg = get_global_vision_llm_config(config_id)
if not global_cfg:
logger.error(f"Global vision LLM config {config_id} not found")
global_model = get_global_model(config_id)
if not global_model or not _has_capability(global_model, "vision"):
logger.error(f"Global vision model {config_id} not found")
return None
if global_cfg.get("custom_provider"):
provider_prefix = global_cfg["custom_provider"]
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
else:
provider_prefix = VISION_PROVIDER_MAP.get(
global_cfg["provider"].upper(),
global_cfg["provider"].lower(),
global_connection = get_global_connection(global_model["connection_id"])
if not global_connection:
logger.error(
"Global connection %s not found for model %s",
global_model["connection_id"],
config_id,
)
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
return None
litellm_kwargs = {
"model": model_string,
"api_key": global_cfg["api_key"],
}
api_base = resolve_api_base(
provider=global_cfg.get("provider"),
provider_prefix=provider_prefix,
config_api_base=global_cfg.get("api_base"),
model_string, litellm_kwargs = _chat_litellm_from_resolved(
conn=global_connection,
model_id=global_model["model_id"],
)
if api_base:
litellm_kwargs["api_base"] = api_base
if global_cfg.get("litellm_params"):
litellm_kwargs.update(global_cfg["litellm_params"])
from app.agents.chat.runtime.llm_config import (
SanitizedChatLiteLLM,
@ -581,7 +523,7 @@ async def get_vision_llm(
inner_llm = SanitizedChatLiteLLM(**litellm_kwargs)
billing_tier = str(global_cfg.get("billing_tier", "free")).lower()
billing_tier = str(global_model.get("billing_tier", "free")).lower()
if billing_tier == "premium":
return QuotaCheckedVisionLLM(
inner_llm,
@ -589,47 +531,23 @@ async def get_vision_llm(
search_space_id=search_space_id,
billing_tier=billing_tier,
base_model=model_string,
quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"),
quota_reserve_tokens=global_model.get("catalog", {}).get(
"quota_reserve_tokens"
),
)
return inner_llm
# User-owned (positive ID) BYOK configs — always free.
result = await session.execute(
select(VisionLLMConfig).where(
VisionLLMConfig.id == config_id,
VisionLLMConfig.search_space_id == search_space_id,
)
)
vision_cfg = result.scalars().first()
if not vision_cfg:
model = await _get_db_model(session, config_id, search_space)
if not model or not _has_capability(model, "vision"):
logger.error(
f"Vision LLM config {config_id} not found in search space {search_space_id}"
f"Vision model {config_id} not found in search space {search_space_id}"
)
return None
if vision_cfg.custom_provider:
provider_prefix = vision_cfg.custom_provider
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
else:
provider_prefix = VISION_PROVIDER_MAP.get(
vision_cfg.provider.value.upper(),
vision_cfg.provider.value.lower(),
)
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
litellm_kwargs = {
"model": model_string,
"api_key": vision_cfg.api_key,
}
api_base = resolve_api_base(
provider=vision_cfg.provider.value,
provider_prefix=provider_prefix,
config_api_base=vision_cfg.api_base,
_, litellm_kwargs = _chat_litellm_from_resolved(
conn=model.connection,
model_id=model.model_id,
)
if api_base:
litellm_kwargs["api_base"] = api_base
if vision_cfg.litellm_params:
litellm_kwargs.update(vision_cfg.litellm_params)
from app.agents.chat.runtime.llm_config import (
SanitizedChatLiteLLM,

View file

@ -3,29 +3,17 @@ from typing import Any
from litellm import Router
from app.services.provider_api_base import resolve_api_base
from app.services.model_resolver import (
NATIVE_PROVIDER_PREFIX,
native_connection_from_config,
to_litellm,
)
logger = logging.getLogger(__name__)
VISION_AUTO_MODE_ID = 0
VISION_PROVIDER_MAP = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GOOGLE": "gemini",
"AZURE_OPENAI": "azure",
"VERTEX_AI": "vertex_ai",
"BEDROCK": "bedrock",
"XAI": "xai",
"OPENROUTER": "openrouter",
"OLLAMA": "ollama_chat",
"GROQ": "groq",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
"DEEPSEEK": "openai",
"MISTRAL": "mistral",
"CUSTOM": "custom",
}
VISION_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX
class VisionLLMRouterService:
@ -110,32 +98,11 @@ class VisionLLMRouterService:
if not config.get("model_name") or not config.get("api_key"):
return None
provider = config.get("provider", "").upper()
if config.get("custom_provider"):
provider_prefix = config["custom_provider"]
model_string = f"{provider_prefix}/{config['model_name']}"
else:
provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{config['model_name']}"
litellm_params: dict[str, Any] = {
"model": model_string,
"api_key": config.get("api_key"),
}
api_base = resolve_api_base(
provider=provider,
provider_prefix=provider_prefix,
config_api_base=config.get("api_base"),
model_string, resolved_kwargs = to_litellm(
native_connection_from_config(config),
config["model_name"],
)
if api_base:
litellm_params["api_base"] = api_base
if config.get("api_version"):
litellm_params["api_version"] = config["api_version"]
if config.get("litellm_params"):
litellm_params.update(config["litellm_params"])
litellm_params: dict[str, Any] = {"model": model_string, **resolved_kwargs}
deployment: dict[str, Any] = {
"model_name": "auto",