From 62ff97c830f167c3810a7f50643f6cb779f98aad Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:48:23 +0530 Subject: [PATCH] refactor(llm): route calls through resolved models --- .../app/services/billable_calls.py | 4 +- .../app/services/image_gen_router_service.py | 53 +-- .../app/services/llm_router_service.py | 88 +--- surfsense_backend/app/services/llm_service.py | 434 +++++++----------- .../app/services/vision_llm_router_service.py | 53 +-- 5 files changed, 209 insertions(+), 423 deletions(-) diff --git a/surfsense_backend/app/services/billable_calls.py b/surfsense_backend/app/services/billable_calls.py index 92ccd6a78..356195f6a 100644 --- a/surfsense_backend/app/services/billable_calls.py +++ b/surfsense_backend/app/services/billable_calls.py @@ -450,10 +450,10 @@ async def _resolve_agent_billing_for_search_space( thread_id: int | None = None, ) -> tuple[UUID, str, str]: """Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space - agent LLM. + chat model. Used by Celery tasks (podcast generation, video presentation) to bill the - search-space owner's premium credit pool when the agent LLM is premium. + search-space owner's premium credit pool when the chat model is premium. Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``: diff --git a/surfsense_backend/app/services/image_gen_router_service.py b/surfsense_backend/app/services/image_gen_router_service.py index b4de2a0bf..0b03f5c6d 100644 --- a/surfsense_backend/app/services/image_gen_router_service.py +++ b/surfsense_backend/app/services/image_gen_router_service.py @@ -20,7 +20,11 @@ from typing import Any from litellm import Router from litellm.utils import ImageResponse -from app.services.provider_api_base import resolve_api_base +from app.services.model_resolver import ( + NATIVE_PROVIDER_PREFIX, + native_connection_from_config, + to_litellm, +) logger = logging.getLogger(__name__) @@ -30,17 +34,7 @@ IMAGE_GEN_AUTO_MODE_ID = 0 # Provider mapping for LiteLLM model string construction. # Only includes providers that support image generation. # See: https://docs.litellm.ai/docs/image_generation#supported-providers -IMAGE_GEN_PROVIDER_MAP = { - "OPENAI": "openai", - "AZURE_OPENAI": "azure", - "GOOGLE": "gemini", # Google AI Studio - "VERTEX_AI": "vertex_ai", - "BEDROCK": "bedrock", # AWS Bedrock - "RECRAFT": "recraft", - "OPENROUTER": "openrouter", - "XINFERENCE": "xinference", - "NSCALE": "nscale", -} +IMAGE_GEN_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX class ImageGenRouterService: @@ -153,38 +147,11 @@ class ImageGenRouterService: if not config.get("model_name") or not config.get("api_key"): return None - # Build model string - provider = config.get("provider", "").upper() - if config.get("custom_provider"): - provider_prefix = config["custom_provider"] - else: - provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower()) - model_string = f"{provider_prefix}/{config['model_name']}" - - # Build litellm params - litellm_params: dict[str, Any] = { - "model": model_string, - "api_key": config.get("api_key"), - } - - # Resolve ``api_base`` so deployments don't silently inherit - # ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against - # the wrong provider (see ``provider_api_base`` docstring). - api_base = resolve_api_base( - provider=provider, - provider_prefix=provider_prefix, - config_api_base=config.get("api_base"), + model_string, resolved_kwargs = to_litellm( + native_connection_from_config(config), + config["model_name"], ) - if api_base: - litellm_params["api_base"] = api_base - - # Add api_version (required for Azure) - if config.get("api_version"): - litellm_params["api_version"] = config["api_version"] - - # Add any additional litellm parameters - if config.get("litellm_params"): - litellm_params.update(config["litellm_params"]) + litellm_params: dict[str, Any] = {"model": model_string, **resolved_kwargs} # All configs use same alias "auto" for unified routing deployment: dict[str, Any] = { diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index d220aa346..69feb30eb 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -30,6 +30,11 @@ from litellm.exceptions import ( ) from pydantic import Field +from app.services.model_resolver import ( + NATIVE_PROVIDER_PREFIX, + native_connection_from_config, + to_litellm, +) from app.utils.perf import get_perf_logger litellm.json_logs = False @@ -96,52 +101,8 @@ def _sanitize_content(content: Any) -> Any: # Special ID for Auto mode - uses router for load balancing AUTO_MODE_ID = 0 -# Provider mapping for LiteLLM model string construction -PROVIDER_MAP = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "OLLAMA": "ollama_chat", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "OPENROUTER": "openrouter", - "COMETAPI": "cometapi", - "XAI": "xai", - "BEDROCK": "bedrock", - "AWS_BEDROCK": "bedrock", # Legacy support - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", - "GITHUB_MODELS": "github", - "HUGGINGFACE": "huggingface", - "MINIMAX": "openai", - "CUSTOM": "custom", -} - - -# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were -# hoisted to ``app.services.provider_api_base`` so vision and image-gen -# call sites can share the exact same defense (OpenRouter / Groq / etc. -# 404-ing against an inherited Azure endpoint). Re-exported here for -# backward compatibility with any external import. -from app.services.provider_api_base import ( # noqa: E402 - resolve_api_base, -) +# Historical export kept for callers that still import PROVIDER_MAP. +PROVIDER_MAP = NATIVE_PROVIDER_PREFIX class LLMRouterService: @@ -420,38 +381,11 @@ class LLMRouterService: if not config.get("model_name") or not config.get("api_key"): return None - # Build model string - provider = config.get("provider", "").upper() - if config.get("custom_provider"): - provider_prefix = config["custom_provider"] - model_string = f"{provider_prefix}/{config['model_name']}" - else: - provider_prefix = PROVIDER_MAP.get(provider, provider.lower()) - model_string = f"{provider_prefix}/{config['model_name']}" - - # Build litellm params - litellm_params = { - "model": model_string, - "api_key": config.get("api_key"), - } - - # Resolve ``api_base``. Config value wins; otherwise apply a - # provider-aware default so the deployment does not silently - # inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route - # requests to the wrong endpoint. See ``provider_api_base`` - # docstring for the motivating bug (OpenRouter models 404-ing - # against an Azure endpoint). - api_base = resolve_api_base( - provider=provider, - provider_prefix=provider_prefix, - config_api_base=config.get("api_base"), + model_string, resolved_kwargs = to_litellm( + native_connection_from_config(config), + config["model_name"], ) - if api_base: - litellm_params["api_base"] = api_base - - # Add any additional litellm parameters - if config.get("litellm_params"): - litellm_params.update(config["litellm_params"]) + litellm_params = {"model": model_string, **resolved_kwargs} # Extract rate limits if provided deployment = { diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 7061a826f..75451d01f 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -6,9 +6,10 @@ from langchain_core.messages import HumanMessage from langchain_litellm import ChatLiteLLM from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from sqlalchemy.orm import selectinload from app.config import config -from app.db import NewLLMConfig, SearchSpace +from app.db import Model, SearchSpace from app.services.llm_router_service import ( AUTO_MODE_ID, ChatLiteLLMRouter, @@ -16,7 +17,7 @@ from app.services.llm_router_service import ( get_auto_mode_llm, is_auto_mode, ) -from app.services.provider_api_base import resolve_api_base +from app.services.model_resolver import native_connection_from_config, to_litellm from app.services.token_tracking_service import token_tracker # Configure litellm to automatically drop unsupported parameters @@ -66,6 +67,29 @@ def _is_interactive_auth_provider( return False +def _legacy_config_connection( + *, + provider: str, + model_name: str, + api_key: str | None, + api_base: str | None, + custom_provider: str | None = None, + litellm_params: dict | None = None, + api_version: str | None = None, +) -> tuple[str, dict]: + cfg = { + "provider": provider, + "model_name": model_name, + "api_key": api_key, + "api_base": api_base, + "custom_provider": custom_provider, + "api_version": api_version, + "litellm_params": litellm_params or {}, + } + conn = native_connection_from_config(cfg) + return to_litellm(conn, model_name) + + class LLMRole: AGENT = "agent" # For agent/chat operations @@ -102,6 +126,60 @@ def get_global_llm_config(llm_config_id: int) -> dict | None: return None +def get_global_model(model_id: int) -> dict | None: + return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None) + + +def get_global_connection(connection_id: int) -> dict | None: + return next( + (c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id), + None, + ) + + +def _has_capability(model: dict | Model, capability: str) -> bool: + caps = ( + model.get("capabilities", {}) + if isinstance(model, dict) + else model.capabilities or {} + ) + return bool(caps.get(capability)) + + +def _chat_litellm_from_resolved( + *, + conn: dict | object, + model_id: str, + disable_streaming: bool = False, +) -> tuple[str, dict]: + model_string, resolved_kwargs = to_litellm(conn, model_id) + litellm_kwargs = {"model": model_string, **resolved_kwargs} + if disable_streaming: + litellm_kwargs["disable_streaming"] = True + return model_string, litellm_kwargs + + +async def _get_db_model( + session: AsyncSession, + model_id: int, + search_space: SearchSpace, +) -> Model | None: + result = await session.execute( + select(Model) + .options(selectinload(Model.connection)) + .where(Model.id == model_id, Model.enabled.is_(True)) + ) + model = result.scalars().first() + if not model or not model.connection or not model.connection.enabled: + return None + conn = model.connection + if conn.search_space_id and conn.search_space_id != search_space.id: + return None + if conn.user_id and conn.user_id != search_space.user_id: + return None + return model + + async def validate_llm_config( provider: str, model_name: str, @@ -146,62 +224,15 @@ async def validate_llm_config( return False, msg try: - # Build the model string for litellm - if custom_provider: - model_string = f"{custom_provider}/{model_name}" - else: - # Map provider enum to litellm format - provider_map = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "OLLAMA": "ollama_chat", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "OPENROUTER": "openrouter", - "COMETAPI": "cometapi", - "XAI": "xai", - "BEDROCK": "bedrock", - "AWS_BEDROCK": "bedrock", # Legacy support (backward compatibility) - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - # Chinese LLM providers - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", # GLM needs special handling - "MINIMAX": "openai", - "GITHUB_MODELS": "github", - } - provider_prefix = provider_map.get(provider, provider.lower()) - model_string = f"{provider_prefix}/{model_name}" - - # Create ChatLiteLLM instance - litellm_kwargs = { - "model": model_string, - "api_key": api_key, - "timeout": 30, # Set a timeout for validation - } - - # Add optional parameters - if api_base: - litellm_kwargs["api_base"] = api_base - - # Add any additional litellm parameters - if litellm_params: - litellm_kwargs.update(litellm_params) + model_string, resolved_kwargs = _legacy_config_connection( + provider=provider, + model_name=model_name, + api_key=api_key, + api_base=api_base, + custom_provider=custom_provider, + litellm_params=litellm_params, + ) + litellm_kwargs = {"model": model_string, **resolved_kwargs, "timeout": 30} from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -283,9 +314,9 @@ async def get_search_space_llm_instance( logger.error(f"Search space {search_space_id} not found") return None - # Get the appropriate LLM config ID based on role + # Get the appropriate model binding ID based on role if role == LLMRole.AGENT: - llm_config_id = search_space.agent_llm_id + llm_config_id = search_space.chat_model_id else: logger.error(f"Invalid LLM role: {role}") return None @@ -312,70 +343,26 @@ async def get_search_space_llm_instance( logger.error(f"Failed to create ChatLiteLLMRouter: {e}") return None - # Check if this is a global config (negative ID) + # Check if this is a global virtual model (negative ID) if llm_config_id < 0: - global_config = get_global_llm_config(llm_config_id) - if not global_config: - logger.error(f"Global LLM config {llm_config_id} not found") + global_model = get_global_model(llm_config_id) + if not global_model or not _has_capability(global_model, "chat"): + logger.error(f"Global chat model {llm_config_id} not found") + return None + global_connection = get_global_connection(global_model["connection_id"]) + if not global_connection: + logger.error( + "Global connection %s not found for model %s", + global_model["connection_id"], + llm_config_id, + ) return None - # Build model string for global config - if global_config.get("custom_provider"): - model_string = ( - f"{global_config['custom_provider']}/{global_config['model_name']}" - ) - else: - provider_map = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "OLLAMA": "ollama_chat", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "OPENROUTER": "openrouter", - "COMETAPI": "cometapi", - "XAI": "xai", - "BEDROCK": "bedrock", - "AWS_BEDROCK": "bedrock", - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", - "MINIMAX": "openai", - } - provider_prefix = provider_map.get( - global_config["provider"], global_config["provider"].lower() - ) - model_string = f"{provider_prefix}/{global_config['model_name']}" - - # Create ChatLiteLLM instance from global config - litellm_kwargs = { - "model": model_string, - "api_key": global_config["api_key"], - } - - if global_config.get("api_base"): - litellm_kwargs["api_base"] = global_config["api_base"] - - if global_config.get("litellm_params"): - litellm_kwargs.update(global_config["litellm_params"]) - - if disable_streaming: - litellm_kwargs["disable_streaming"] = True + _, litellm_kwargs = _chat_litellm_from_resolved( + conn=global_connection, + model_id=global_model["model_id"], + disable_streaming=disable_streaming, + ) from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -383,80 +370,18 @@ async def get_search_space_llm_instance( return SanitizedChatLiteLLM(**litellm_kwargs) - # Get the LLM configuration from database (NewLLMConfig) - result = await session.execute( - select(NewLLMConfig).where( - NewLLMConfig.id == llm_config_id, - NewLLMConfig.search_space_id == search_space_id, - ) - ) - llm_config = result.scalars().first() - - if not llm_config: + model = await _get_db_model(session, llm_config_id, search_space) + if not model or not _has_capability(model, "chat"): logger.error( - f"LLM config {llm_config_id} not found in search space {search_space_id}" + f"Chat model {llm_config_id} not found in search space {search_space_id}" ) return None - # Build the model string for litellm - if llm_config.custom_provider: - model_string = f"{llm_config.custom_provider}/{llm_config.model_name}" - else: - # Map provider enum to litellm format - provider_map = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "OLLAMA": "ollama_chat", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "OPENROUTER": "openrouter", - "COMETAPI": "cometapi", - "XAI": "xai", - "BEDROCK": "bedrock", - "AWS_BEDROCK": "bedrock", - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", - "MINIMAX": "openai", - "GITHUB_MODELS": "github", - } - provider_prefix = provider_map.get( - llm_config.provider.value, llm_config.provider.value.lower() - ) - model_string = f"{provider_prefix}/{llm_config.model_name}" - - # Create ChatLiteLLM instance - litellm_kwargs = { - "model": model_string, - "api_key": llm_config.api_key, - } - - # Add optional parameters - if llm_config.api_base: - litellm_kwargs["api_base"] = llm_config.api_base - - # Add any additional litellm parameters - if llm_config.litellm_params: - litellm_kwargs.update(llm_config.litellm_params) - - if disable_streaming: - litellm_kwargs["disable_streaming"] = True + _, litellm_kwargs = _chat_litellm_from_resolved( + conn=model.connection, + model_id=model.model_id, + disable_streaming=disable_streaming, + ) from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -474,7 +399,7 @@ async def get_search_space_llm_instance( async def get_agent_llm( session: AsyncSession, search_space_id: int, disable_streaming: bool = False ) -> ChatLiteLLM | ChatLiteLLMRouter | None: - """Get the search space's agent LLM instance for chat operations.""" + """Get the search space's chat model instance.""" return await get_search_space_llm_instance( session, search_space_id, @@ -488,22 +413,19 @@ async def get_vision_llm( ) -> ChatLiteLLM | ChatLiteLLMRouter | None: """Get the search space's vision LLM instance for screenshot analysis. - Resolves from the dedicated VisionLLMConfig system: + Resolves from the new connection/model role bindings: - Auto mode (ID 0): VisionLLMRouterService - - Global (negative ID): YAML configs - - DB (positive ID): VisionLLMConfig table + - Global (negative ID): virtual GLOBAL models from YAML + - DB (positive ID): Model + Connection tables Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM` so each ``ainvoke`` debits the search-space owner's premium credit pool. User-owned BYOK configs and free global configs are returned unwrapped — they don't consume premium credit (issue M). """ - from app.db import VisionLLMConfig from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM from app.services.vision_llm_router_service import ( - VISION_PROVIDER_MAP, VisionLLMRouterService, - get_global_vision_llm_config, is_vision_auto_mode, ) @@ -516,13 +438,43 @@ async def get_vision_llm( logger.error(f"Search space {search_space_id} not found") return None - config_id = search_space.vision_llm_config_id + owner_user_id = search_space.user_id + + # Prefer the selected chat model when it is vision-capable. + chat_model_id = search_space.chat_model_id + if chat_model_id and chat_model_id != AUTO_MODE_ID: + if chat_model_id < 0: + chat_model = get_global_model(chat_model_id) + if chat_model and _has_capability(chat_model, "vision"): + global_connection = get_global_connection(chat_model["connection_id"]) + if global_connection: + model_string, litellm_kwargs = _chat_litellm_from_resolved( + conn=global_connection, + model_id=chat_model["model_id"], + ) + from app.agents.chat.runtime.llm_config import ( + SanitizedChatLiteLLM, + ) + + return SanitizedChatLiteLLM(**litellm_kwargs) + else: + chat_model = await _get_db_model(session, chat_model_id, search_space) + if chat_model and _has_capability(chat_model, "vision"): + _, litellm_kwargs = _chat_litellm_from_resolved( + conn=chat_model.connection, + model_id=chat_model.model_id, + ) + from app.agents.chat.runtime.llm_config import ( + SanitizedChatLiteLLM, + ) + + return SanitizedChatLiteLLM(**litellm_kwargs) + + config_id = search_space.vision_model_id if config_id is None: logger.error(f"No vision LLM configured for search space {search_space_id}") return None - owner_user_id = search_space.user_id - if is_vision_auto_mode(config_id): if not VisionLLMRouterService.is_initialized(): logger.error( @@ -546,34 +498,24 @@ async def get_vision_llm( return None if config_id < 0: - global_cfg = get_global_vision_llm_config(config_id) - if not global_cfg: - logger.error(f"Global vision LLM config {config_id} not found") + global_model = get_global_model(config_id) + if not global_model or not _has_capability(global_model, "vision"): + logger.error(f"Global vision model {config_id} not found") return None - if global_cfg.get("custom_provider"): - provider_prefix = global_cfg["custom_provider"] - model_string = f"{provider_prefix}/{global_cfg['model_name']}" - else: - provider_prefix = VISION_PROVIDER_MAP.get( - global_cfg["provider"].upper(), - global_cfg["provider"].lower(), + global_connection = get_global_connection(global_model["connection_id"]) + if not global_connection: + logger.error( + "Global connection %s not found for model %s", + global_model["connection_id"], + config_id, ) - model_string = f"{provider_prefix}/{global_cfg['model_name']}" + return None - litellm_kwargs = { - "model": model_string, - "api_key": global_cfg["api_key"], - } - api_base = resolve_api_base( - provider=global_cfg.get("provider"), - provider_prefix=provider_prefix, - config_api_base=global_cfg.get("api_base"), + model_string, litellm_kwargs = _chat_litellm_from_resolved( + conn=global_connection, + model_id=global_model["model_id"], ) - if api_base: - litellm_kwargs["api_base"] = api_base - if global_cfg.get("litellm_params"): - litellm_kwargs.update(global_cfg["litellm_params"]) from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -581,7 +523,7 @@ async def get_vision_llm( inner_llm = SanitizedChatLiteLLM(**litellm_kwargs) - billing_tier = str(global_cfg.get("billing_tier", "free")).lower() + billing_tier = str(global_model.get("billing_tier", "free")).lower() if billing_tier == "premium": return QuotaCheckedVisionLLM( inner_llm, @@ -589,47 +531,23 @@ async def get_vision_llm( search_space_id=search_space_id, billing_tier=billing_tier, base_model=model_string, - quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"), + quota_reserve_tokens=global_model.get("catalog", {}).get( + "quota_reserve_tokens" + ), ) return inner_llm - # User-owned (positive ID) BYOK configs — always free. - result = await session.execute( - select(VisionLLMConfig).where( - VisionLLMConfig.id == config_id, - VisionLLMConfig.search_space_id == search_space_id, - ) - ) - vision_cfg = result.scalars().first() - if not vision_cfg: + model = await _get_db_model(session, config_id, search_space) + if not model or not _has_capability(model, "vision"): logger.error( - f"Vision LLM config {config_id} not found in search space {search_space_id}" + f"Vision model {config_id} not found in search space {search_space_id}" ) return None - if vision_cfg.custom_provider: - provider_prefix = vision_cfg.custom_provider - model_string = f"{provider_prefix}/{vision_cfg.model_name}" - else: - provider_prefix = VISION_PROVIDER_MAP.get( - vision_cfg.provider.value.upper(), - vision_cfg.provider.value.lower(), - ) - model_string = f"{provider_prefix}/{vision_cfg.model_name}" - - litellm_kwargs = { - "model": model_string, - "api_key": vision_cfg.api_key, - } - api_base = resolve_api_base( - provider=vision_cfg.provider.value, - provider_prefix=provider_prefix, - config_api_base=vision_cfg.api_base, + _, litellm_kwargs = _chat_litellm_from_resolved( + conn=model.connection, + model_id=model.model_id, ) - if api_base: - litellm_kwargs["api_base"] = api_base - if vision_cfg.litellm_params: - litellm_kwargs.update(vision_cfg.litellm_params) from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, diff --git a/surfsense_backend/app/services/vision_llm_router_service.py b/surfsense_backend/app/services/vision_llm_router_service.py index ed5de921c..0c7182ecf 100644 --- a/surfsense_backend/app/services/vision_llm_router_service.py +++ b/surfsense_backend/app/services/vision_llm_router_service.py @@ -3,29 +3,17 @@ from typing import Any from litellm import Router -from app.services.provider_api_base import resolve_api_base +from app.services.model_resolver import ( + NATIVE_PROVIDER_PREFIX, + native_connection_from_config, + to_litellm, +) logger = logging.getLogger(__name__) VISION_AUTO_MODE_ID = 0 -VISION_PROVIDER_MAP = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GOOGLE": "gemini", - "AZURE_OPENAI": "azure", - "VERTEX_AI": "vertex_ai", - "BEDROCK": "bedrock", - "XAI": "xai", - "OPENROUTER": "openrouter", - "OLLAMA": "ollama_chat", - "GROQ": "groq", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "DEEPSEEK": "openai", - "MISTRAL": "mistral", - "CUSTOM": "custom", -} +VISION_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX class VisionLLMRouterService: @@ -110,32 +98,11 @@ class VisionLLMRouterService: if not config.get("model_name") or not config.get("api_key"): return None - provider = config.get("provider", "").upper() - if config.get("custom_provider"): - provider_prefix = config["custom_provider"] - model_string = f"{provider_prefix}/{config['model_name']}" - else: - provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower()) - model_string = f"{provider_prefix}/{config['model_name']}" - - litellm_params: dict[str, Any] = { - "model": model_string, - "api_key": config.get("api_key"), - } - - api_base = resolve_api_base( - provider=provider, - provider_prefix=provider_prefix, - config_api_base=config.get("api_base"), + model_string, resolved_kwargs = to_litellm( + native_connection_from_config(config), + config["model_name"], ) - if api_base: - litellm_params["api_base"] = api_base - - if config.get("api_version"): - litellm_params["api_version"] = config["api_version"] - - if config.get("litellm_params"): - litellm_params.update(config["litellm_params"]) + litellm_params: dict[str, Any] = {"model": model_string, **resolved_kwargs} deployment: dict[str, Any] = { "model_name": "auto",