mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-12 20:45:20 +02:00
refactor(llm): route calls through resolved models
This commit is contained in:
parent
8b59ca59c1
commit
62ff97c830
5 changed files with 209 additions and 423 deletions
|
|
@ -450,10 +450,10 @@ async def _resolve_agent_billing_for_search_space(
|
|||
thread_id: int | None = None,
|
||||
) -> tuple[UUID, str, str]:
|
||||
"""Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space
|
||||
agent LLM.
|
||||
chat model.
|
||||
|
||||
Used by Celery tasks (podcast generation, video presentation) to bill the
|
||||
search-space owner's premium credit pool when the agent LLM is premium.
|
||||
search-space owner's premium credit pool when the chat model is premium.
|
||||
|
||||
Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``:
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,11 @@ from typing import Any
|
|||
from litellm import Router
|
||||
from litellm.utils import ImageResponse
|
||||
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.model_resolver import (
|
||||
NATIVE_PROVIDER_PREFIX,
|
||||
native_connection_from_config,
|
||||
to_litellm,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -30,17 +34,7 @@ IMAGE_GEN_AUTO_MODE_ID = 0
|
|||
# Provider mapping for LiteLLM model string construction.
|
||||
# Only includes providers that support image generation.
|
||||
# See: https://docs.litellm.ai/docs/image_generation#supported-providers
|
||||
IMAGE_GEN_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"GOOGLE": "gemini", # Google AI Studio
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock", # AWS Bedrock
|
||||
"RECRAFT": "recraft",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XINFERENCE": "xinference",
|
||||
"NSCALE": "nscale",
|
||||
}
|
||||
IMAGE_GEN_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX
|
||||
|
||||
|
||||
class ImageGenRouterService:
|
||||
|
|
@ -153,38 +147,11 @@ class ImageGenRouterService:
|
|||
if not config.get("model_name") or not config.get("api_key"):
|
||||
return None
|
||||
|
||||
# Build model string
|
||||
provider = config.get("provider", "").upper()
|
||||
if config.get("custom_provider"):
|
||||
provider_prefix = config["custom_provider"]
|
||||
else:
|
||||
provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
|
||||
# Build litellm params
|
||||
litellm_params: dict[str, Any] = {
|
||||
"model": model_string,
|
||||
"api_key": config.get("api_key"),
|
||||
}
|
||||
|
||||
# Resolve ``api_base`` so deployments don't silently inherit
|
||||
# ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against
|
||||
# the wrong provider (see ``provider_api_base`` docstring).
|
||||
api_base = resolve_api_base(
|
||||
provider=provider,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=config.get("api_base"),
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
native_connection_from_config(config),
|
||||
config["model_name"],
|
||||
)
|
||||
if api_base:
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
# Add api_version (required for Azure)
|
||||
if config.get("api_version"):
|
||||
litellm_params["api_version"] = config["api_version"]
|
||||
|
||||
# Add any additional litellm parameters
|
||||
if config.get("litellm_params"):
|
||||
litellm_params.update(config["litellm_params"])
|
||||
litellm_params: dict[str, Any] = {"model": model_string, **resolved_kwargs}
|
||||
|
||||
# All configs use same alias "auto" for unified routing
|
||||
deployment: dict[str, Any] = {
|
||||
|
|
|
|||
|
|
@ -30,6 +30,11 @@ from litellm.exceptions import (
|
|||
)
|
||||
from pydantic import Field
|
||||
|
||||
from app.services.model_resolver import (
|
||||
NATIVE_PROVIDER_PREFIX,
|
||||
native_connection_from_config,
|
||||
to_litellm,
|
||||
)
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
litellm.json_logs = False
|
||||
|
|
@ -96,52 +101,8 @@ def _sanitize_content(content: Any) -> Any:
|
|||
# Special ID for Auto mode - uses router for load balancing
|
||||
AUTO_MODE_ID = 0
|
||||
|
||||
# Provider mapping for LiteLLM model string construction
|
||||
PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
"COMETAPI": "cometapi",
|
||||
"XAI": "xai",
|
||||
"BEDROCK": "bedrock",
|
||||
"AWS_BEDROCK": "bedrock", # Legacy support
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"REPLICATE": "replicate",
|
||||
"PERPLEXITY": "perplexity",
|
||||
"ANYSCALE": "anyscale",
|
||||
"DEEPINFRA": "deepinfra",
|
||||
"CEREBRAS": "cerebras",
|
||||
"SAMBANOVA": "sambanova",
|
||||
"AI21": "ai21",
|
||||
"CLOUDFLARE": "cloudflare",
|
||||
"DATABRICKS": "databricks",
|
||||
"DEEPSEEK": "openai",
|
||||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai",
|
||||
"GITHUB_MODELS": "github",
|
||||
"HUGGINGFACE": "huggingface",
|
||||
"MINIMAX": "openai",
|
||||
"CUSTOM": "custom",
|
||||
}
|
||||
|
||||
|
||||
# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were
|
||||
# hoisted to ``app.services.provider_api_base`` so vision and image-gen
|
||||
# call sites can share the exact same defense (OpenRouter / Groq / etc.
|
||||
# 404-ing against an inherited Azure endpoint). Re-exported here for
|
||||
# backward compatibility with any external import.
|
||||
from app.services.provider_api_base import ( # noqa: E402
|
||||
resolve_api_base,
|
||||
)
|
||||
# Historical export kept for callers that still import PROVIDER_MAP.
|
||||
PROVIDER_MAP = NATIVE_PROVIDER_PREFIX
|
||||
|
||||
|
||||
class LLMRouterService:
|
||||
|
|
@ -420,38 +381,11 @@ class LLMRouterService:
|
|||
if not config.get("model_name") or not config.get("api_key"):
|
||||
return None
|
||||
|
||||
# Build model string
|
||||
provider = config.get("provider", "").upper()
|
||||
if config.get("custom_provider"):
|
||||
provider_prefix = config["custom_provider"]
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
else:
|
||||
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
|
||||
# Build litellm params
|
||||
litellm_params = {
|
||||
"model": model_string,
|
||||
"api_key": config.get("api_key"),
|
||||
}
|
||||
|
||||
# Resolve ``api_base``. Config value wins; otherwise apply a
|
||||
# provider-aware default so the deployment does not silently
|
||||
# inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route
|
||||
# requests to the wrong endpoint. See ``provider_api_base``
|
||||
# docstring for the motivating bug (OpenRouter models 404-ing
|
||||
# against an Azure endpoint).
|
||||
api_base = resolve_api_base(
|
||||
provider=provider,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=config.get("api_base"),
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
native_connection_from_config(config),
|
||||
config["model_name"],
|
||||
)
|
||||
if api_base:
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
# Add any additional litellm parameters
|
||||
if config.get("litellm_params"):
|
||||
litellm_params.update(config["litellm_params"])
|
||||
litellm_params = {"model": model_string, **resolved_kwargs}
|
||||
|
||||
# Extract rate limits if provided
|
||||
deployment = {
|
||||
|
|
|
|||
|
|
@ -6,9 +6,10 @@ from langchain_core.messages import HumanMessage
|
|||
from langchain_litellm import ChatLiteLLM
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.config import config
|
||||
from app.db import NewLLMConfig, SearchSpace
|
||||
from app.db import Model, SearchSpace
|
||||
from app.services.llm_router_service import (
|
||||
AUTO_MODE_ID,
|
||||
ChatLiteLLMRouter,
|
||||
|
|
@ -16,7 +17,7 @@ from app.services.llm_router_service import (
|
|||
get_auto_mode_llm,
|
||||
is_auto_mode,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.model_resolver import native_connection_from_config, to_litellm
|
||||
from app.services.token_tracking_service import token_tracker
|
||||
|
||||
# Configure litellm to automatically drop unsupported parameters
|
||||
|
|
@ -66,6 +67,29 @@ def _is_interactive_auth_provider(
|
|||
return False
|
||||
|
||||
|
||||
def _legacy_config_connection(
    *,
    provider: str,
    model_name: str,
    api_key: str | None,
    api_base: str | None,
    custom_provider: str | None = None,
    litellm_params: dict | None = None,
    api_version: str | None = None,
) -> tuple[str, dict]:
    """Resolve flat, legacy-style provider settings through the model resolver.

    The keyword arguments are packed into a config dict, converted to a
    connection with ``native_connection_from_config`` and then handed to
    ``to_litellm`` together with ``model_name``.

    Args:
        provider: Provider identifier stored under the ``"provider"`` key.
        model_name: Model name; stored in the config and also passed to
            ``to_litellm`` as the model to resolve.
        api_key: API key for the provider, if any.
        api_base: Explicit API base URL, if any.
        custom_provider: Optional custom provider value.
        litellm_params: Optional extra parameters; a falsy value is
            replaced by an empty dict.
        api_version: Optional API version.

    Returns:
        Whatever ``to_litellm`` returns; callers in this module unpack it
        as ``(model_string, resolved_kwargs)``.
    """
    extra_params = litellm_params if litellm_params else {}
    legacy_config = dict(
        provider=provider,
        model_name=model_name,
        api_key=api_key,
        api_base=api_base,
        custom_provider=custom_provider,
        api_version=api_version,
        litellm_params=extra_params,
    )
    return to_litellm(native_connection_from_config(legacy_config), model_name)
|
||||
|
||||
|
||||
class LLMRole:
|
||||
AGENT = "agent" # For agent/chat operations
|
||||
|
||||
|
|
@ -102,6 +126,60 @@ def get_global_llm_config(llm_config_id: int) -> dict | None:
|
|||
return None
|
||||
|
||||
|
||||
def get_global_model(model_id: int) -> dict | None:
    """Return the first ``config.GLOBAL_MODELS`` entry whose ``"id"`` equals ``model_id``.

    Returns ``None`` when no entry matches.
    """
    for candidate in config.GLOBAL_MODELS:
        if candidate.get("id") == model_id:
            return candidate
    return None
|
||||
|
||||
|
||||
def get_global_connection(connection_id: int) -> dict | None:
    """Return the first ``config.GLOBAL_CONNECTIONS`` entry whose ``"id"`` equals ``connection_id``.

    Returns ``None`` when no entry matches.
    """
    for candidate in config.GLOBAL_CONNECTIONS:
        if candidate.get("id") == connection_id:
            return candidate
    return None
|
||||
|
||||
|
||||
def _has_capability(model: dict | Model, capability: str) -> bool:
|
||||
caps = (
|
||||
model.get("capabilities", {})
|
||||
if isinstance(model, dict)
|
||||
else model.capabilities or {}
|
||||
)
|
||||
return bool(caps.get(capability))
|
||||
|
||||
|
||||
def _chat_litellm_from_resolved(
    *,
    conn: dict | object,
    model_id: str,
    disable_streaming: bool = False,
) -> tuple[str, dict]:
    """Build chat-model keyword arguments from a resolved connection.

    ``to_litellm(conn, model_id)`` supplies the model string and the
    resolved keyword arguments. The result starts with a ``"model"`` entry
    and is then overlaid with the resolved arguments, so a resolved key of
    the same name takes precedence.

    Args:
        conn: Connection passed through unchanged to ``to_litellm``.
        model_id: Model identifier passed through to ``to_litellm``.
        disable_streaming: When truthy, ``"disable_streaming": True`` is
            added to the returned keyword arguments.

    Returns:
        ``(model_string, kwargs)`` where ``kwargs`` is the assembled dict.
    """
    resolved_model, extra_kwargs = to_litellm(conn, model_id)

    chat_kwargs = {"model": resolved_model}
    chat_kwargs.update(extra_kwargs)
    if disable_streaming:
        chat_kwargs["disable_streaming"] = True

    return resolved_model, chat_kwargs
|
||||
|
||||
|
||||
async def _get_db_model(
    session: AsyncSession,
    model_id: int,
    search_space: SearchSpace,
) -> Model | None:
    """Fetch an enabled ``Model`` row that ``search_space`` is allowed to use.

    The row is selected by primary key with ``Model.enabled`` true and its
    ``connection`` relationship loaded via ``selectinload``.

    Args:
        session: Async SQLAlchemy session used to run the query.
        model_id: Primary key of the model to load.
        search_space: Search space whose ``id`` and ``user_id`` are used
            for the scope checks.

    Returns:
        The model, or ``None`` when no enabled row exists, the connection
        is missing or disabled, or the connection is scoped (truthy
        ``search_space_id`` / ``user_id``) to a different search space or
        user.
    """
    query = (
        select(Model)
        .options(selectinload(Model.connection))
        .where(Model.id == model_id, Model.enabled.is_(True))
    )
    row = (await session.execute(query)).scalars().first()
    if not row:
        return None

    connection = row.connection
    if not connection or not connection.enabled:
        return None

    # A connection bound to a search space is only usable from that space.
    bound_space_id = connection.search_space_id
    if bound_space_id and bound_space_id != search_space.id:
        return None

    # A connection bound to a user is only usable by that user's search space.
    bound_user_id = connection.user_id
    if bound_user_id and bound_user_id != search_space.user_id:
        return None

    return row
|
||||
|
||||
|
||||
async def validate_llm_config(
|
||||
provider: str,
|
||||
model_name: str,
|
||||
|
|
@ -146,62 +224,15 @@ async def validate_llm_config(
|
|||
return False, msg
|
||||
|
||||
try:
|
||||
# Build the model string for litellm
|
||||
if custom_provider:
|
||||
model_string = f"{custom_provider}/{model_name}"
|
||||
else:
|
||||
# Map provider enum to litellm format
|
||||
provider_map = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
"COMETAPI": "cometapi",
|
||||
"XAI": "xai",
|
||||
"BEDROCK": "bedrock",
|
||||
"AWS_BEDROCK": "bedrock", # Legacy support (backward compatibility)
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"REPLICATE": "replicate",
|
||||
"PERPLEXITY": "perplexity",
|
||||
"ANYSCALE": "anyscale",
|
||||
"DEEPINFRA": "deepinfra",
|
||||
"CEREBRAS": "cerebras",
|
||||
"SAMBANOVA": "sambanova",
|
||||
"AI21": "ai21",
|
||||
"CLOUDFLARE": "cloudflare",
|
||||
"DATABRICKS": "databricks",
|
||||
# Chinese LLM providers
|
||||
"DEEPSEEK": "openai",
|
||||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai", # GLM needs special handling
|
||||
"MINIMAX": "openai",
|
||||
"GITHUB_MODELS": "github",
|
||||
}
|
||||
provider_prefix = provider_map.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{model_name}"
|
||||
|
||||
# Create ChatLiteLLM instance
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": api_key,
|
||||
"timeout": 30, # Set a timeout for validation
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if api_base:
|
||||
litellm_kwargs["api_base"] = api_base
|
||||
|
||||
# Add any additional litellm parameters
|
||||
if litellm_params:
|
||||
litellm_kwargs.update(litellm_params)
|
||||
model_string, resolved_kwargs = _legacy_config_connection(
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
custom_provider=custom_provider,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
litellm_kwargs = {"model": model_string, **resolved_kwargs, "timeout": 30}
|
||||
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
SanitizedChatLiteLLM,
|
||||
|
|
@ -283,9 +314,9 @@ async def get_search_space_llm_instance(
|
|||
logger.error(f"Search space {search_space_id} not found")
|
||||
return None
|
||||
|
||||
# Get the appropriate LLM config ID based on role
|
||||
# Get the appropriate model binding ID based on role
|
||||
if role == LLMRole.AGENT:
|
||||
llm_config_id = search_space.agent_llm_id
|
||||
llm_config_id = search_space.chat_model_id
|
||||
else:
|
||||
logger.error(f"Invalid LLM role: {role}")
|
||||
return None
|
||||
|
|
@ -312,70 +343,26 @@ async def get_search_space_llm_instance(
|
|||
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
|
||||
return None
|
||||
|
||||
# Check if this is a global config (negative ID)
|
||||
# Check if this is a global virtual model (negative ID)
|
||||
if llm_config_id < 0:
|
||||
global_config = get_global_llm_config(llm_config_id)
|
||||
if not global_config:
|
||||
logger.error(f"Global LLM config {llm_config_id} not found")
|
||||
global_model = get_global_model(llm_config_id)
|
||||
if not global_model or not _has_capability(global_model, "chat"):
|
||||
logger.error(f"Global chat model {llm_config_id} not found")
|
||||
return None
|
||||
global_connection = get_global_connection(global_model["connection_id"])
|
||||
if not global_connection:
|
||||
logger.error(
|
||||
"Global connection %s not found for model %s",
|
||||
global_model["connection_id"],
|
||||
llm_config_id,
|
||||
)
|
||||
return None
|
||||
|
||||
# Build model string for global config
|
||||
if global_config.get("custom_provider"):
|
||||
model_string = (
|
||||
f"{global_config['custom_provider']}/{global_config['model_name']}"
|
||||
)
|
||||
else:
|
||||
provider_map = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
"COMETAPI": "cometapi",
|
||||
"XAI": "xai",
|
||||
"BEDROCK": "bedrock",
|
||||
"AWS_BEDROCK": "bedrock",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"REPLICATE": "replicate",
|
||||
"PERPLEXITY": "perplexity",
|
||||
"ANYSCALE": "anyscale",
|
||||
"DEEPINFRA": "deepinfra",
|
||||
"CEREBRAS": "cerebras",
|
||||
"SAMBANOVA": "sambanova",
|
||||
"AI21": "ai21",
|
||||
"CLOUDFLARE": "cloudflare",
|
||||
"DATABRICKS": "databricks",
|
||||
"DEEPSEEK": "openai",
|
||||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai",
|
||||
"MINIMAX": "openai",
|
||||
}
|
||||
provider_prefix = provider_map.get(
|
||||
global_config["provider"], global_config["provider"].lower()
|
||||
)
|
||||
model_string = f"{provider_prefix}/{global_config['model_name']}"
|
||||
|
||||
# Create ChatLiteLLM instance from global config
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": global_config["api_key"],
|
||||
}
|
||||
|
||||
if global_config.get("api_base"):
|
||||
litellm_kwargs["api_base"] = global_config["api_base"]
|
||||
|
||||
if global_config.get("litellm_params"):
|
||||
litellm_kwargs.update(global_config["litellm_params"])
|
||||
|
||||
if disable_streaming:
|
||||
litellm_kwargs["disable_streaming"] = True
|
||||
_, litellm_kwargs = _chat_litellm_from_resolved(
|
||||
conn=global_connection,
|
||||
model_id=global_model["model_id"],
|
||||
disable_streaming=disable_streaming,
|
||||
)
|
||||
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
SanitizedChatLiteLLM,
|
||||
|
|
@ -383,80 +370,18 @@ async def get_search_space_llm_instance(
|
|||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
# Get the LLM configuration from database (NewLLMConfig)
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig).where(
|
||||
NewLLMConfig.id == llm_config_id,
|
||||
NewLLMConfig.search_space_id == search_space_id,
|
||||
)
|
||||
)
|
||||
llm_config = result.scalars().first()
|
||||
|
||||
if not llm_config:
|
||||
model = await _get_db_model(session, llm_config_id, search_space)
|
||||
if not model or not _has_capability(model, "chat"):
|
||||
logger.error(
|
||||
f"LLM config {llm_config_id} not found in search space {search_space_id}"
|
||||
f"Chat model {llm_config_id} not found in search space {search_space_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Build the model string for litellm
|
||||
if llm_config.custom_provider:
|
||||
model_string = f"{llm_config.custom_provider}/{llm_config.model_name}"
|
||||
else:
|
||||
# Map provider enum to litellm format
|
||||
provider_map = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
"COMETAPI": "cometapi",
|
||||
"XAI": "xai",
|
||||
"BEDROCK": "bedrock",
|
||||
"AWS_BEDROCK": "bedrock",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"REPLICATE": "replicate",
|
||||
"PERPLEXITY": "perplexity",
|
||||
"ANYSCALE": "anyscale",
|
||||
"DEEPINFRA": "deepinfra",
|
||||
"CEREBRAS": "cerebras",
|
||||
"SAMBANOVA": "sambanova",
|
||||
"AI21": "ai21",
|
||||
"CLOUDFLARE": "cloudflare",
|
||||
"DATABRICKS": "databricks",
|
||||
"DEEPSEEK": "openai",
|
||||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai",
|
||||
"MINIMAX": "openai",
|
||||
"GITHUB_MODELS": "github",
|
||||
}
|
||||
provider_prefix = provider_map.get(
|
||||
llm_config.provider.value, llm_config.provider.value.lower()
|
||||
)
|
||||
model_string = f"{provider_prefix}/{llm_config.model_name}"
|
||||
|
||||
# Create ChatLiteLLM instance
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": llm_config.api_key,
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if llm_config.api_base:
|
||||
litellm_kwargs["api_base"] = llm_config.api_base
|
||||
|
||||
# Add any additional litellm parameters
|
||||
if llm_config.litellm_params:
|
||||
litellm_kwargs.update(llm_config.litellm_params)
|
||||
|
||||
if disable_streaming:
|
||||
litellm_kwargs["disable_streaming"] = True
|
||||
_, litellm_kwargs = _chat_litellm_from_resolved(
|
||||
conn=model.connection,
|
||||
model_id=model.model_id,
|
||||
disable_streaming=disable_streaming,
|
||||
)
|
||||
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
SanitizedChatLiteLLM,
|
||||
|
|
@ -474,7 +399,7 @@ async def get_search_space_llm_instance(
|
|||
async def get_agent_llm(
|
||||
session: AsyncSession, search_space_id: int, disable_streaming: bool = False
|
||||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||
"""Get the search space's agent LLM instance for chat operations."""
|
||||
"""Get the search space's chat model instance."""
|
||||
return await get_search_space_llm_instance(
|
||||
session,
|
||||
search_space_id,
|
||||
|
|
@ -488,22 +413,19 @@ async def get_vision_llm(
|
|||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||
"""Get the search space's vision LLM instance for screenshot analysis.
|
||||
|
||||
Resolves from the dedicated VisionLLMConfig system:
|
||||
Resolves from the new connection/model role bindings:
|
||||
- Auto mode (ID 0): VisionLLMRouterService
|
||||
- Global (negative ID): YAML configs
|
||||
- DB (positive ID): VisionLLMConfig table
|
||||
- Global (negative ID): virtual GLOBAL models from YAML
|
||||
- DB (positive ID): Model + Connection tables
|
||||
|
||||
Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM`
|
||||
so each ``ainvoke`` debits the search-space owner's premium credit
|
||||
pool. User-owned BYOK configs and free global configs are returned
|
||||
unwrapped — they don't consume premium credit (issue M).
|
||||
"""
|
||||
from app.db import VisionLLMConfig
|
||||
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
|
||||
from app.services.vision_llm_router_service import (
|
||||
VISION_PROVIDER_MAP,
|
||||
VisionLLMRouterService,
|
||||
get_global_vision_llm_config,
|
||||
is_vision_auto_mode,
|
||||
)
|
||||
|
||||
|
|
@ -516,13 +438,43 @@ async def get_vision_llm(
|
|||
logger.error(f"Search space {search_space_id} not found")
|
||||
return None
|
||||
|
||||
config_id = search_space.vision_llm_config_id
|
||||
owner_user_id = search_space.user_id
|
||||
|
||||
# Prefer the selected chat model when it is vision-capable.
|
||||
chat_model_id = search_space.chat_model_id
|
||||
if chat_model_id and chat_model_id != AUTO_MODE_ID:
|
||||
if chat_model_id < 0:
|
||||
chat_model = get_global_model(chat_model_id)
|
||||
if chat_model and _has_capability(chat_model, "vision"):
|
||||
global_connection = get_global_connection(chat_model["connection_id"])
|
||||
if global_connection:
|
||||
model_string, litellm_kwargs = _chat_litellm_from_resolved(
|
||||
conn=global_connection,
|
||||
model_id=chat_model["model_id"],
|
||||
)
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
SanitizedChatLiteLLM,
|
||||
)
|
||||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
else:
|
||||
chat_model = await _get_db_model(session, chat_model_id, search_space)
|
||||
if chat_model and _has_capability(chat_model, "vision"):
|
||||
_, litellm_kwargs = _chat_litellm_from_resolved(
|
||||
conn=chat_model.connection,
|
||||
model_id=chat_model.model_id,
|
||||
)
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
SanitizedChatLiteLLM,
|
||||
)
|
||||
|
||||
return SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
config_id = search_space.vision_model_id
|
||||
if config_id is None:
|
||||
logger.error(f"No vision LLM configured for search space {search_space_id}")
|
||||
return None
|
||||
|
||||
owner_user_id = search_space.user_id
|
||||
|
||||
if is_vision_auto_mode(config_id):
|
||||
if not VisionLLMRouterService.is_initialized():
|
||||
logger.error(
|
||||
|
|
@ -546,34 +498,24 @@ async def get_vision_llm(
|
|||
return None
|
||||
|
||||
if config_id < 0:
|
||||
global_cfg = get_global_vision_llm_config(config_id)
|
||||
if not global_cfg:
|
||||
logger.error(f"Global vision LLM config {config_id} not found")
|
||||
global_model = get_global_model(config_id)
|
||||
if not global_model or not _has_capability(global_model, "vision"):
|
||||
logger.error(f"Global vision model {config_id} not found")
|
||||
return None
|
||||
|
||||
if global_cfg.get("custom_provider"):
|
||||
provider_prefix = global_cfg["custom_provider"]
|
||||
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
|
||||
else:
|
||||
provider_prefix = VISION_PROVIDER_MAP.get(
|
||||
global_cfg["provider"].upper(),
|
||||
global_cfg["provider"].lower(),
|
||||
global_connection = get_global_connection(global_model["connection_id"])
|
||||
if not global_connection:
|
||||
logger.error(
|
||||
"Global connection %s not found for model %s",
|
||||
global_model["connection_id"],
|
||||
config_id,
|
||||
)
|
||||
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
|
||||
return None
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": global_cfg["api_key"],
|
||||
}
|
||||
api_base = resolve_api_base(
|
||||
provider=global_cfg.get("provider"),
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=global_cfg.get("api_base"),
|
||||
model_string, litellm_kwargs = _chat_litellm_from_resolved(
|
||||
conn=global_connection,
|
||||
model_id=global_model["model_id"],
|
||||
)
|
||||
if api_base:
|
||||
litellm_kwargs["api_base"] = api_base
|
||||
if global_cfg.get("litellm_params"):
|
||||
litellm_kwargs.update(global_cfg["litellm_params"])
|
||||
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
SanitizedChatLiteLLM,
|
||||
|
|
@ -581,7 +523,7 @@ async def get_vision_llm(
|
|||
|
||||
inner_llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
|
||||
billing_tier = str(global_cfg.get("billing_tier", "free")).lower()
|
||||
billing_tier = str(global_model.get("billing_tier", "free")).lower()
|
||||
if billing_tier == "premium":
|
||||
return QuotaCheckedVisionLLM(
|
||||
inner_llm,
|
||||
|
|
@ -589,47 +531,23 @@ async def get_vision_llm(
|
|||
search_space_id=search_space_id,
|
||||
billing_tier=billing_tier,
|
||||
base_model=model_string,
|
||||
quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"),
|
||||
quota_reserve_tokens=global_model.get("catalog", {}).get(
|
||||
"quota_reserve_tokens"
|
||||
),
|
||||
)
|
||||
return inner_llm
|
||||
|
||||
# User-owned (positive ID) BYOK configs — always free.
|
||||
result = await session.execute(
|
||||
select(VisionLLMConfig).where(
|
||||
VisionLLMConfig.id == config_id,
|
||||
VisionLLMConfig.search_space_id == search_space_id,
|
||||
)
|
||||
)
|
||||
vision_cfg = result.scalars().first()
|
||||
if not vision_cfg:
|
||||
model = await _get_db_model(session, config_id, search_space)
|
||||
if not model or not _has_capability(model, "vision"):
|
||||
logger.error(
|
||||
f"Vision LLM config {config_id} not found in search space {search_space_id}"
|
||||
f"Vision model {config_id} not found in search space {search_space_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
if vision_cfg.custom_provider:
|
||||
provider_prefix = vision_cfg.custom_provider
|
||||
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
|
||||
else:
|
||||
provider_prefix = VISION_PROVIDER_MAP.get(
|
||||
vision_cfg.provider.value.upper(),
|
||||
vision_cfg.provider.value.lower(),
|
||||
)
|
||||
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": vision_cfg.api_key,
|
||||
}
|
||||
api_base = resolve_api_base(
|
||||
provider=vision_cfg.provider.value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=vision_cfg.api_base,
|
||||
_, litellm_kwargs = _chat_litellm_from_resolved(
|
||||
conn=model.connection,
|
||||
model_id=model.model_id,
|
||||
)
|
||||
if api_base:
|
||||
litellm_kwargs["api_base"] = api_base
|
||||
if vision_cfg.litellm_params:
|
||||
litellm_kwargs.update(vision_cfg.litellm_params)
|
||||
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
SanitizedChatLiteLLM,
|
||||
|
|
|
|||
|
|
@ -3,29 +3,17 @@ from typing import Any
|
|||
|
||||
from litellm import Router
|
||||
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.model_resolver import (
|
||||
NATIVE_PROVIDER_PREFIX,
|
||||
native_connection_from_config,
|
||||
to_litellm,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VISION_AUTO_MODE_ID = 0
|
||||
|
||||
VISION_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GOOGLE": "gemini",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock",
|
||||
"XAI": "xai",
|
||||
"OPENROUTER": "openrouter",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"GROQ": "groq",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"DEEPSEEK": "openai",
|
||||
"MISTRAL": "mistral",
|
||||
"CUSTOM": "custom",
|
||||
}
|
||||
VISION_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX
|
||||
|
||||
|
||||
class VisionLLMRouterService:
|
||||
|
|
@ -110,32 +98,11 @@ class VisionLLMRouterService:
|
|||
if not config.get("model_name") or not config.get("api_key"):
|
||||
return None
|
||||
|
||||
provider = config.get("provider", "").upper()
|
||||
if config.get("custom_provider"):
|
||||
provider_prefix = config["custom_provider"]
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
else:
|
||||
provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
|
||||
litellm_params: dict[str, Any] = {
|
||||
"model": model_string,
|
||||
"api_key": config.get("api_key"),
|
||||
}
|
||||
|
||||
api_base = resolve_api_base(
|
||||
provider=provider,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=config.get("api_base"),
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
native_connection_from_config(config),
|
||||
config["model_name"],
|
||||
)
|
||||
if api_base:
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
if config.get("api_version"):
|
||||
litellm_params["api_version"] = config["api_version"]
|
||||
|
||||
if config.get("litellm_params"):
|
||||
litellm_params.update(config["litellm_params"])
|
||||
litellm_params: dict[str, Any] = {"model": model_string, **resolved_kwargs}
|
||||
|
||||
deployment: dict[str, Any] = {
|
||||
"model_name": "auto",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue