diff --git a/surfsense_backend/app/agents/chat/runtime/llm_config.py b/surfsense_backend/app/agents/chat/runtime/llm_config.py index 03d7f548e..b9344e001 100644 --- a/surfsense_backend/app/agents/chat/runtime/llm_config.py +++ b/surfsense_backend/app/agents/chat/runtime/llm_config.py @@ -2,9 +2,9 @@ LLM configuration utilities for SurfSense agents. This module provides functions for loading LLM configurations from: -1. Auto mode (ID 0) - Uses LiteLLM Router for load balancing +1. Auto mode (ID 0) - Resolved by callers to a concrete model-connection model 2. YAML files (global configs with negative IDs) -3. Database NewLLMConfig table (user-created configs with positive IDs) +3. Database model-connections table (user-created configs with positive IDs) It also provides utilities for creating ChatLiteLLM instances and managing prompt configurations. @@ -33,9 +33,7 @@ from app.agents.chat.runtime.prompt_caching import ( from app.services.llm_router_service import ( AUTO_MODE_ID, ChatLiteLLMRouter, - LLMRouterService, _sanitize_content, - get_auto_mode_llm, is_auto_mode, ) @@ -92,14 +90,6 @@ class SanitizedChatLiteLLM(ChatLiteLLM): yield chunk -# Re-exported under the historical name ``PROVIDER_MAP``. Source of truth lives -# in provider_capabilities so the YAML loader can resolve prefixes during -# app.config init without importing the agent/tools tree. -from app.services.provider_capabilities import ( # noqa: E402 - _PROVIDER_PREFIX_MAP as PROVIDER_MAP, -) - - def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None: """Attach a ``profile`` dict to ChatLiteLLM with model context metadata.""" try: @@ -122,7 +112,8 @@ class AgentConfig: Complete configuration for the SurfSense agent. This combines LLM settings with prompt configuration from NewLLMConfig. - Supports Auto mode (ID 0) which uses LiteLLM Router for load balancing. + Supports Auto mode metadata (ID 0). Runtime callers must resolve Auto to + a concrete global or BYOK model before constructing ChatLiteLLM. """ # LLM Model Settings @@ -219,7 +210,7 @@ class AgentConfig: # BYOK rows have no curated flag; ask LiteLLM (default-allow on # unknown). The streaming safety net still blocks explicit text-only. supports_image_input=derive_supports_image_input( - provider=provider_value, + litellm_provider=provider_value.lower(), model_name=config.model_name, base_model=base_model, custom_provider=config.custom_provider, @@ -238,7 +229,7 @@ class AgentConfig: system_instructions = yaml_config.get("system_instructions", "") - provider = yaml_config.get("provider", "").upper() + provider = yaml_config.get("litellm_provider", "") model_name = yaml_config.get("model_name", "") custom_provider = yaml_config.get("custom_provider") litellm_params = yaml_config.get("litellm_params") or {} @@ -254,7 +245,7 @@ class AgentConfig: supports_image_input = bool(yaml_config.get("supports_image_input")) else: supports_image_input = derive_supports_image_input( - provider=provider, + litellm_provider=provider, model_name=model_name, base_model=base_model, custom_provider=custom_provider, @@ -383,9 +374,6 @@ async def load_agent_config( ) -> "AgentConfig | None": """Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB.""" if is_auto_mode(config_id): - if not LLMRouterService.is_initialized(): - print("Error: Auto mode requested but LLM Router not initialized") - return None return AgentConfig.from_auto_mode() if config_id < 0: @@ -408,9 +396,8 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: if llm_config.get("custom_provider"): model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}" else: - provider = llm_config.get("provider", "").upper() - provider_prefix = PROVIDER_MAP.get(provider, provider.lower()) - model_string = f"{provider_prefix}/{llm_config['model_name']}" + litellm_provider = llm_config.get("litellm_provider", "openai") + model_string = f"{litellm_provider}/{llm_config['model_name']}" litellm_kwargs = { "model": model_string, @@ -433,29 +420,15 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: def create_chat_litellm_from_agent_config( agent_config: AgentConfig, ) -> ChatLiteLLM | ChatLiteLLMRouter | None: - """Create a ChatLiteLLM (or, for Auto mode, a load-balancing router) from config.""" + """Create a ChatLiteLLM from an already resolved concrete model config.""" if agent_config.is_auto_mode: - if not LLMRouterService.is_initialized(): - print("Error: Auto mode requested but LLM Router not initialized") - return None - try: - router_llm = get_auto_mode_llm() - if router_llm is not None: - # Universal injection points only: auto-mode fans out across - # providers, so provider-specific kwargs have no known target. - apply_litellm_prompt_caching(router_llm, agent_config=agent_config) - return router_llm - except Exception as e: - print(f"Error creating ChatLiteLLMRouter: {e}") - return None + print("Error: Auto mode must be resolved to a concrete model before LLM creation") + return None if agent_config.custom_provider: model_string = f"{agent_config.custom_provider}/{agent_config.model_name}" else: - provider_prefix = PROVIDER_MAP.get( - agent_config.provider, agent_config.provider.lower() - ) - model_string = f"{provider_prefix}/{agent_config.model_name}" + model_string = f"{agent_config.provider}/{agent_config.model_name}" litellm_kwargs = { "model": model_string, diff --git a/surfsense_backend/app/routes/anonymous_chat_routes.py b/surfsense_backend/app/routes/anonymous_chat_routes.py index ad3277375..aba1a3a12 100644 --- a/surfsense_backend/app/routes/anonymous_chat_routes.py +++ b/surfsense_backend/app/routes/anonymous_chat_routes.py @@ -132,7 +132,7 @@ async def list_anonymous_models(): id=cfg.get("id", 0), name=cfg.get("name", ""), description=cfg.get("description"), - provider=cfg.get("provider", ""), + provider=cfg.get("litellm_provider", ""), model_name=cfg.get("model_name", ""), billing_tier=cfg.get("billing_tier", "free"), is_premium=cfg.get("billing_tier", "free") == "premium", @@ -161,7 +161,7 @@ async def get_anonymous_model(slug: str): id=cfg.get("id", 0), name=cfg.get("name", ""), description=cfg.get("description"), - provider=cfg.get("provider", ""), + provider=cfg.get("litellm_provider", ""), model_name=cfg.get("model_name", ""), billing_tier=cfg.get("billing_tier", "free"), is_premium=cfg.get("billing_tier", "free") == "premium", diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py index e4f08f604..df218daac 100644 --- a/surfsense_backend/app/routes/vision_llm_routes.py +++ b/surfsense_backend/app/routes/vision_llm_routes.py @@ -96,7 +96,7 @@ async def get_global_vision_llm_configs( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("provider"), + "provider": cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index 9bbca8669..ee8c4b8dc 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -23,9 +23,10 @@ from uuid import UUID from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from app.config import config -from app.db import NewChatThread +from app.db import Connection, Model, NewChatThread from app.services.quality_score import _QUALITY_TOP_K from app.services.token_quota_service import TokenQuotaService @@ -61,11 +62,20 @@ def _is_usable_global_config(cfg: dict) -> bool: return bool( cfg.get("id") is not None and cfg.get("model_name") - and cfg.get("provider") + and cfg.get("litellm_provider") and cfg.get("api_key") ) +def _has_capability(model: dict | Model, capability: str) -> bool: + caps = ( + model.get("capabilities", {}) + if isinstance(model, dict) + else model.capabilities or {} + ) + return bool(caps.get(capability)) + + def _prune_runtime_cooldowns(now_ts: float | None = None) -> None: now = time.time() if now_ts is None else now_ts stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now] @@ -186,15 +196,19 @@ def _cfg_supports_image_input(cfg: dict) -> bool: else None ) return derive_supports_image_input( - provider=cfg.get("provider"), + litellm_provider=cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), ) -def _global_candidates(*, requires_image_input: bool = False) -> list[dict]: - """Return Auto-eligible global cfgs. +def _global_candidates( + *, + capability: str = "chat", + requires_image_input: bool = False, +) -> list[dict]: + """Return Auto-eligible global virtual models. Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers @@ -205,17 +219,135 @@ def _global_candidates(*, requires_image_input: bool = False) -> list[dict]: filters out configs whose ``supports_image_input`` resolves to False so a text-only deployment can't be pinned for an image request. """ - candidates = [ - cfg + connection_by_id = { + int(conn.get("id")): conn + for conn in config.GLOBAL_CONNECTIONS + if conn.get("id") is not None + } + config_by_model_name = { + cfg.get("model_name"): cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg) - and not cfg.get("health_gated") - and not _is_runtime_cooled_down(int(cfg.get("id", 0))) - and (not requires_image_input or _cfg_supports_image_input(cfg)) - ] + } + candidates: list[dict] = [] + for model in config.GLOBAL_MODELS: + model_id = int(model.get("id", 0)) + if model_id >= 0 or _is_runtime_cooled_down(model_id): + continue + if not _has_capability(model, capability): + continue + cfg = config_by_model_name.get(model.get("model_id")) or {} + if cfg.get("health_gated"): + continue + if requires_image_input and not _has_capability(model, "vision"): + continue + if requires_image_input and cfg and not _cfg_supports_image_input(cfg): + continue + connection = connection_by_id.get(int(model.get("connection_id", 0))) + if not connection: + continue + catalog = model.get("catalog") or {} + candidates.append( + { + "id": model_id, + "model_id": model.get("model_id"), + "source": "global", + "connection": connection, + "capabilities": model.get("capabilities") or {}, + "billing_tier": model.get("billing_tier", "free"), + "litellm_provider": connection.get("litellm_provider"), + "model_name": model.get("model_id"), + "auto_pin_tier": catalog.get("auto_pin_tier") + or cfg.get("auto_pin_tier") + or "A", + "quality_score": catalog.get("quality_score") + or cfg.get("quality_score") + or cfg.get("quality_score_static") + or 50, + } + ) return sorted(candidates, key=lambda c: int(c.get("id", 0))) +async def _db_candidates( + session: AsyncSession, + *, + search_space_id: int, + user_id: str | UUID | None, + capability: str, + requires_image_input: bool = False, +) -> list[dict]: + parsed_user_id = _to_uuid(user_id) + stmt = ( + select(Model) + .options(selectinload(Model.connection)) + .join(Connection, Model.connection_id == Connection.id) + .where(Model.enabled.is_(True), Connection.enabled.is_(True)) + ) + result = await session.execute(stmt) + candidates: list[dict] = [] + for model in result.scalars().all(): + conn = model.connection + if not conn: + continue + if conn.search_space_id is not None and conn.search_space_id != search_space_id: + continue + if conn.user_id is not None and parsed_user_id is not None and conn.user_id != parsed_user_id: + continue + if conn.user_id is not None and parsed_user_id is None: + continue + if not _has_capability(model, capability): + continue + if requires_image_input and not _has_capability(model, "vision"): + continue + model_id = int(model.id) + if _is_runtime_cooled_down(model_id): + continue + catalog = model.catalog or {} + candidates.append( + { + "id": model_id, + "model_id": model.model_id, + "source": "db", + "connection": conn, + "capabilities": model.capabilities or {}, + "billing_tier": "byok", + "litellm_provider": conn.litellm_provider, + "model_name": model.model_id, + "auto_pin_tier": catalog.get("auto_pin_tier") or "BYOK", + "quality_score": catalog.get("quality_score") or 75, + } + ) + return sorted(candidates, key=lambda c: int(c.get("id", 0))) + + +async def auto_model_candidates( + session: AsyncSession, + *, + search_space_id: int, + user_id: str | UUID | None, + capability: str, + requires_image_input: bool = False, + exclude_model_ids: set[int] | None = None, +) -> list[dict]: + excluded_ids = {int(mid) for mid in (exclude_model_ids or set())} + db_candidates = await _db_candidates( + session, + search_space_id=search_space_id, + user_id=user_id, + capability=capability, + requires_image_input=requires_image_input, + ) + candidates = [ + *_global_candidates( + capability=capability, + requires_image_input=requires_image_input, + ), + *db_candidates, + ] + return [c for c in candidates if int(c.get("id", 0)) not in excluded_ids] + + def _tier_of(cfg: dict) -> str: return str(cfg.get("billing_tier", "free")).lower() @@ -223,8 +355,9 @@ def _tier_of(cfg: dict) -> str: def _is_preferred_premium_auto_config(cfg: dict) -> bool: """Return True for the operator-preferred premium Auto model.""" return ( - _tier_of(cfg) == "premium" - and str(cfg.get("provider", "")).upper() == "AZURE_OPENAI" + cfg.get("source") == "global" + and _tier_of(cfg) == "premium" + and str(cfg.get("litellm_provider", "")).lower() == "azure" and str(cfg.get("model_name", "")).lower() == "gpt-5.4" ) @@ -251,6 +384,11 @@ def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]: return top_k[idx], len(top_k) +def choose_auto_model_candidate(candidates: list[dict], seed_id: int) -> dict: + selected, _ = _select_pin(candidates, seed_id) + return selected + + def _to_uuid(user_id: str | UUID | None) -> UUID | None: if user_id is None: return None @@ -326,20 +464,23 @@ async def resolve_or_get_pinned_llm_config_id( ) excluded_ids = {int(cid) for cid in (exclude_config_ids or set())} - candidates = [ - c - for c in _global_candidates(requires_image_input=requires_image_input) - if int(c.get("id", 0)) not in excluded_ids - ] + candidates = await auto_model_candidates( + session, + search_space_id=search_space_id, + user_id=user_id, + capability="chat", + requires_image_input=requires_image_input, + exclude_model_ids=excluded_ids, + ) if not candidates: if requires_image_input: # Distinguish the "no vision-capable cfg" case from generic # "no usable cfg" so the streaming task can map this to the # MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error. raise ValueError( - "No vision-capable global LLM configs are available for Auto mode" + "No vision-capable LLM models are available for Auto mode" ) - raise ValueError("No usable global LLM configs are available for Auto mode") + raise ValueError("No usable LLM models are available for Auto mode") candidate_by_id = {int(c["id"]): c for c in candidates} # Reuse an existing valid pin without re-checking current quota (no silent @@ -379,24 +520,13 @@ async def resolve_or_get_pinned_llm_config_id( # log that explicitly so operators can correlate the re-pin with # the user's image attachment instead of suspecting a cooldown. if requires_image_input: - try: - pinned_global = next( - c - for c in config.GLOBAL_LLM_CONFIGS - if int(c.get("id", 0)) == int(pinned_id) - ) - except StopIteration: - pinned_global = None - if pinned_global is not None and not _cfg_supports_image_input( - pinned_global - ): - logger.info( - "auto_pin_repinned_for_image thread_id=%s search_space_id=%s " - "previous_config_id=%s", - thread_id, - search_space_id, - pinned_id, - ) + logger.info( + "auto_pin_repinned_for_image thread_id=%s search_space_id=%s " + "previous_config_id=%s", + thread_id, + search_space_id, + pinned_id, + ) logger.info( "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s", thread_id, diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 69feb30eb..a151a0d6e 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -30,11 +30,7 @@ from litellm.exceptions import ( ) from pydantic import Field -from app.services.model_resolver import ( - NATIVE_PROVIDER_PREFIX, - native_connection_from_config, - to_litellm, -) +from app.services.model_resolver import native_connection_from_config, to_litellm from app.utils.perf import get_perf_logger litellm.json_logs = False @@ -101,10 +97,6 @@ def _sanitize_content(content: Any) -> Any: # Special ID for Auto mode - uses router for load balancing AUTO_MODE_ID = 0 -# Historical export kept for callers that still import PROVIDER_MAP. -PROVIDER_MAP = NATIVE_PROVIDER_PREFIX - - class LLMRouterService: """ Singleton service for managing LiteLLM Router. diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 75451d01f..86a9c8556 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -10,13 +10,11 @@ from sqlalchemy.orm import selectinload from app.config import config from app.db import Model, SearchSpace -from app.services.llm_router_service import ( - AUTO_MODE_ID, - ChatLiteLLMRouter, - LLMRouterService, - get_auto_mode_llm, - is_auto_mode, +from app.services.auto_model_pin_service import ( + auto_model_candidates, + choose_auto_model_candidate, ) +from app.services.llm_router_service import AUTO_MODE_ID, ChatLiteLLMRouter, is_auto_mode from app.services.model_resolver import native_connection_from_config, to_litellm from app.services.token_tracking_service import token_tracker @@ -78,7 +76,7 @@ def _legacy_config_connection( api_version: str | None = None, ) -> tuple[str, dict]: cfg = { - "provider": provider, + "litellm_provider": provider.lower(), "model_name": model_name, "api_key": api_key, "api_base": api_base, @@ -325,23 +323,21 @@ async def get_search_space_llm_instance( logger.error(f"No {role} LLM configured for search space {search_space_id}") return None - # Check for Auto mode (ID 0) - use router for load balancing + # Auto mode resolves to one concrete global or BYOK model from the + # unified model-connections catalog. if is_auto_mode(llm_config_id): - if not LLMRouterService.is_initialized(): - logger.error( - "Auto mode requested but LLM Router not initialized. " - "Ensure global_llm_config.yaml exists with valid configs." - ) - return None - - try: - logger.debug( - f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}" - ) - return get_auto_mode_llm(streaming=not disable_streaming) - except Exception as e: - logger.error(f"Failed to create ChatLiteLLMRouter: {e}") + candidates = await auto_model_candidates( + session, + search_space_id=search_space_id, + user_id=search_space.user_id, + capability="chat", + ) + if not candidates: + logger.error("No chat-capable models available for Auto mode") return None + llm_config_id = int( + choose_auto_model_candidate(candidates, search_space_id)["id"] + ) # Check if this is a global virtual model (negative ID) if llm_config_id < 0: @@ -414,7 +410,7 @@ async def get_vision_llm( """Get the search space's vision LLM instance for screenshot analysis. Resolves from the new connection/model role bindings: - - Auto mode (ID 0): VisionLLMRouterService + - Auto mode (ID 0): unified global/BYOK model candidate selection - Global (negative ID): virtual GLOBAL models from YAML - DB (positive ID): Model + Connection tables @@ -424,10 +420,7 @@ async def get_vision_llm( unwrapped — they don't consume premium credit (issue M). """ from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM - from app.services.vision_llm_router_service import ( - VisionLLMRouterService, - is_vision_auto_mode, - ) + from app.services.vision_llm_router_service import is_vision_auto_mode try: result = await session.execute( @@ -476,26 +469,16 @@ async def get_vision_llm( return None if is_vision_auto_mode(config_id): - if not VisionLLMRouterService.is_initialized(): - logger.error( - "Vision Auto mode requested but Vision LLM Router not initialized" - ) - return None - try: - # Auto mode is currently treated as free at the wrapper - # level — the underlying router can dispatch to either - # premium or free YAML configs but routing decisions are - # opaque. If/when we want to bill Auto-routed vision - # calls we'd need to thread the resolved deployment's - # billing_tier back from the router. For now we keep - # parity with chat Auto, which also doesn't pre-classify. - return ChatLiteLLMRouter( - router=VisionLLMRouterService.get_router(), - streaming=True, - ) - except Exception as e: - logger.error(f"Failed to create vision ChatLiteLLMRouter: {e}") + candidates = await auto_model_candidates( + session, + search_space_id=search_space_id, + user_id=owner_user_id, + capability="vision", + ) + if not candidates: + logger.error("No vision-capable models available for Auto mode") return None + config_id = int(choose_auto_model_candidate(candidates, search_space_id)["id"]) if config_id < 0: global_model = get_global_model(config_id) diff --git a/surfsense_backend/app/services/model_list_service.py b/surfsense_backend/app/services/model_list_service.py index 33837a8a0..1ef0b0c90 100644 --- a/surfsense_backend/app/services/model_list_service.py +++ b/surfsense_backend/app/services/model_list_service.py @@ -154,19 +154,19 @@ def _process_models(raw_models: list[dict]) -> list[dict]: } ) - # 2) Emit for the native provider when we have a mapping - native_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug) - if native_provider: + # 2) Emit for the direct provider when we have a mapping + direct_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug) + if direct_provider: # Google's Gemini API only serves gemini-* models. # Open-source models like gemma-* are NOT available through it. - if native_provider == "GOOGLE" and not model_name.startswith("gemini-"): + if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"): continue processed.append( { "value": model_name, "label": name, - "provider": native_provider, + "provider": direct_provider, "context_window": context_window, } ) diff --git a/surfsense_backend/app/services/model_resolver.py b/surfsense_backend/app/services/model_resolver.py index ec485a5ae..ffa77a9a2 100644 --- a/surfsense_backend/app/services/model_resolver.py +++ b/surfsense_backend/app/services/model_resolver.py @@ -9,53 +9,12 @@ from __future__ import annotations from collections.abc import Mapping from typing import TYPE_CHECKING, Any -from app.services.provider_api_base import resolve_api_base - if TYPE_CHECKING: from app.db import Connection PROTOCOL_OLLAMA = "OLLAMA" PROTOCOL_OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" -PROTOCOL_NATIVE = "NATIVE" - -NATIVE_PROVIDER_PREFIX: dict[str, str] = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "AZURE": "azure", - "OPENROUTER": "openrouter", - "COMETAPI": "cometapi", - "XAI": "xai", - "BEDROCK": "bedrock", - "AWS_BEDROCK": "bedrock", - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", - "GITHUB_MODELS": "github", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - "HUGGINGFACE": "huggingface", - "MINIMAX": "openai", - "RECRAFT": "recraft", - "XINFERENCE": "xinference", - "NSCALE": "nscale", - "CUSTOM": "custom", -} +PROTOCOL_ANTHROPIC = "ANTHROPIC" def ensure_v1(base_url: str | None) -> str | None: @@ -77,6 +36,23 @@ def _protocol_value(protocol: Any) -> str: return getattr(protocol, "value", str(protocol)) +def default_litellm_provider(protocol: Any) -> str: + protocol_value = _protocol_value(protocol) + defaults = { + PROTOCOL_OLLAMA: "ollama_chat", + PROTOCOL_ANTHROPIC: "anthropic", + PROTOCOL_OPENAI_COMPATIBLE: "openai", + } + return defaults.get(protocol_value, "openai") + + +def _execution_api_base(protocol: str, base_url: str | None) -> str | None: + del protocol + if not base_url: + return None + return base_url.rstrip("/") + + def to_litellm( conn: Connection | Mapping[str, Any], model_id: str, @@ -85,38 +61,19 @@ def to_litellm( protocol = _protocol_value(_conn_value(conn, "protocol")) base_url = _conn_value(conn, "base_url") api_key = _conn_value(conn, "api_key") - native_provider = _conn_value(conn, "native_provider") + litellm_provider = ( + _conn_value(conn, "litellm_provider") or default_litellm_provider(protocol) + ) extra = _conn_value(conn, "extra") or {} kwargs: dict[str, Any] = {} if api_key: kwargs["api_key"] = api_key - if protocol == PROTOCOL_OLLAMA: - model_string = f"ollama_chat/{model_id}" - if base_url: - kwargs["api_base"] = base_url.rstrip("/") - elif protocol == PROTOCOL_OPENAI_COMPATIBLE: - model_string = f"openai/{model_id}" - api_base = ensure_v1(base_url) - if api_base: - kwargs["api_base"] = api_base - else: - provider_key = (native_provider or "").upper() - prefix = NATIVE_PROVIDER_PREFIX.get(provider_key, provider_key.lower()) - if prefix == "custom": - custom_provider = extra.get("custom_provider") or native_provider - model_string = f"{custom_provider}/{model_id}" if custom_provider else model_id - else: - model_string = f"{prefix}/{model_id}" - - api_base = resolve_api_base( - provider=provider_key, - provider_prefix=prefix, - config_api_base=base_url, - ) - if api_base: - kwargs["api_base"] = api_base + model_string = f"{litellm_provider}/{model_id}" if litellm_provider else model_id + api_base = _execution_api_base(protocol, base_url) + if api_base: + kwargs["api_base"] = api_base if api_version := extra.get("api_version"): kwargs["api_version"] = api_version @@ -126,18 +83,21 @@ def to_litellm( def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: - """Build an in-memory NATIVE connection mapping from a legacy/global config.""" - provider = str(config.get("provider") or config.get("custom_provider") or "CUSTOM") + """Build an in-memory connection mapping from a global config.""" + protocol = str(config.get("protocol") or PROTOCOL_OPENAI_COMPATIBLE) + litellm_provider = str( + config.get("litellm_provider") + or config.get("custom_provider") + or default_litellm_provider(protocol) + ) extra: dict[str, Any] = { "litellm_params": config.get("litellm_params") or {}, } if config.get("api_version"): extra["api_version"] = config.get("api_version") - if config.get("custom_provider"): - extra["custom_provider"] = config.get("custom_provider") return { - "protocol": PROTOCOL_NATIVE, - "native_provider": provider, + "protocol": protocol, + "litellm_provider": litellm_provider, "base_url": config.get("api_base") or None, "api_key": config.get("api_key") or None, "extra": extra, @@ -145,7 +105,7 @@ def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: __all__ = [ - "NATIVE_PROVIDER_PREFIX", + "default_litellm_provider", "ensure_v1", "native_connection_from_config", "to_litellm", diff --git a/surfsense_backend/app/services/vision_llm_router_service.py b/surfsense_backend/app/services/vision_llm_router_service.py index 0c7182ecf..0ff716324 100644 --- a/surfsense_backend/app/services/vision_llm_router_service.py +++ b/surfsense_backend/app/services/vision_llm_router_service.py @@ -3,19 +3,12 @@ from typing import Any from litellm import Router -from app.services.model_resolver import ( - NATIVE_PROVIDER_PREFIX, - native_connection_from_config, - to_litellm, -) +from app.services.model_resolver import native_connection_from_config, to_litellm logger = logging.getLogger(__name__) VISION_AUTO_MODE_ID = 0 -VISION_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX - - class VisionLLMRouterService: _instance = None _router: Router | None = None @@ -141,12 +134,11 @@ def is_vision_auto_mode(config_id: int | None) -> bool: def build_vision_model_string( - provider: str, model_name: str, custom_provider: str | None + litellm_provider: str, model_name: str, custom_provider: str | None ) -> str: if custom_provider: return f"{custom_provider}/{model_name}" - prefix = VISION_PROVIDER_MAP.get(provider.upper(), provider.lower()) - return f"{prefix}/{model_name}" + return f"{litellm_provider}/{model_name}" def get_global_vision_llm_config(config_id: int) -> dict | None: diff --git a/surfsense_backend/app/services/vision_model_list_service.py b/surfsense_backend/app/services/vision_model_list_service.py index fc459910b..6eae8c455 100644 --- a/surfsense_backend/app/services/vision_model_list_service.py +++ b/surfsense_backend/app/services/vision_model_list_service.py @@ -97,16 +97,16 @@ def _process_vision_models(raw_models: list[dict]) -> list[dict]: } ) - native_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug) - if native_provider: - if native_provider == "GOOGLE" and not model_name.startswith("gemini-"): + direct_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug) + if direct_provider: + if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"): continue processed.append( { "value": model_name, "label": name, - "provider": native_provider, + "provider": direct_provider, "context_window": context_window, } ) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py index 69b9f4ab8..f6fcf75d7 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py @@ -40,7 +40,7 @@ def check_image_input_capability( else None ) if not is_known_text_only_chat_model( - provider=agent_config.provider, + litellm_provider=agent_config.provider, model_name=agent_config.model_name, base_model=agent_base_model, custom_provider=agent_config.custom_provider, diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py index fe3d210bb..d5e8c3729 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py @@ -80,7 +80,6 @@ async def _generate_title( from litellm import acompletion from app.services.llm_router_service import LLMRouterService - from app.services.provider_api_base import resolve_api_base from app.services.token_tracking_service import _turn_accumulator # Excludes this turn's own assistant row (pre-written by @@ -125,26 +124,12 @@ async def _generate_title( router = LLMRouterService.get_router() response = await router.acompletion(model="auto", messages=messages) else: - # Apply the same ``api_base`` cascade chat / vision / image-gen - # call sites use so we never inherit ``litellm.api_base`` - # (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat - # config itself ships an empty ``api_base``. Without this the - # title-gen on an OpenRouter chat config would 404 against the - # inherited Azure endpoint — see ``provider_api_base`` for the - # same bug repro on the image-gen / vision paths. raw_model = getattr(llm, "model", "") or "" - provider_prefix = raw_model.split("/", 1)[0] if "/" in raw_model else None - provider_value = agent_config.provider if agent_config is not None else None - title_api_base = resolve_api_base( - provider=provider_value, - provider_prefix=provider_prefix, - config_api_base=getattr(llm, "api_base", None), - ) response = await acompletion( model=raw_model, messages=messages, api_key=getattr(llm, "api_key", None), - api_base=title_api_base, + api_base=getattr(llm, "api_base", None), ) usage_info = None diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py index 7e2bc950b..f6870f5fa 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py @@ -1,8 +1,8 @@ """Load an LLM + AgentConfig bundle for a given config id. Handles both code paths uniformly: -- ``config_id >= 0`` → database-backed ``NewLLMConfig`` row (per-user/per-space). -- ``config_id < 0`` → YAML-defined global LLM config (built-in defaults). +- ``config_id > 0`` → database-backed model-connection ``Model`` row. +- ``config_id < 0`` → virtual global model materialized from YAML/OpenRouter. Returns ``(llm, agent_config, error_message)``; on success ``error_message`` is ``None``. The caller emits the friendly SSE error frame. @@ -12,15 +12,72 @@ from __future__ import annotations from typing import Any +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from app.agents.chat.runtime.llm_config import ( AgentConfig, - create_chat_litellm_from_agent_config, - create_chat_litellm_from_config, - load_agent_config, - load_global_llm_config_by_id, + SanitizedChatLiteLLM, ) +from app.config import config +from app.db import Model, SearchSpace +from app.services.model_resolver import to_litellm + + +def _agent_config_from_resolved( + *, + config_id: int, + config_name: str | None, + provider: str, + model_name: str, + api_key: str | None, + api_base: str | None, + litellm_params: dict | None, + supports_image_input: bool, + billing_tier: str = "free", +) -> AgentConfig: + return AgentConfig( + provider=provider, + model_name=model_name, + api_key=api_key or "", + api_base=api_base, + custom_provider=None, + litellm_params=litellm_params, + config_id=config_id, + config_name=config_name, + is_auto_mode=False, + billing_tier=billing_tier, + is_premium=billing_tier == "premium", + supports_image_input=supports_image_input, + ) + + +async def _load_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace | None: + result = await session.execute(select(SearchSpace).where(SearchSpace.id == search_space_id)) + return result.scalars().first() + + +async def _load_db_model( + session: AsyncSession, + *, + model_id: int, + search_space: SearchSpace, +) -> Model | None: + result = await session.execute( + select(Model) + .options(selectinload(Model.connection)) + .where(Model.id == model_id, Model.enabled.is_(True)) + ) + model = result.scalars().first() + if not model or not model.connection or not model.connection.enabled: + return None + conn = model.connection + if conn.search_space_id is not None and conn.search_space_id != search_space.id: + return None + if conn.user_id is not None and conn.user_id != search_space.user_id: + return None + return model async def load_llm_bundle( @@ -29,29 +86,67 @@ async def load_llm_bundle( config_id: int, search_space_id: int, ) -> tuple[Any, AgentConfig | None, str | None]: - if config_id >= 0: - loaded_agent_config = await load_agent_config( - session=session, - config_id=config_id, - search_space_id=search_space_id, + search_space = await _load_search_space(session, search_space_id) + if not search_space: + return None, None, f"Search space {search_space_id} not found" + + if config_id > 0: + model = await _load_db_model( + session, + model_id=config_id, + search_space=search_space, ) - if not loaded_agent_config: + if not model or not (model.capabilities or {}).get("chat"): return ( None, None, - f"Failed to load NewLLMConfig with id {config_id}", + f"Failed to load chat model with id {config_id}", ) + model_string, litellm_kwargs = to_litellm(model.connection, model.model_id) + agent_config = _agent_config_from_resolved( + config_id=config_id, + config_name=model.display_name or model.model_id, + provider=model.connection.litellm_provider or "", + model_name=model.model_id, + api_key=model.connection.api_key, + api_base=model.connection.base_url, + litellm_params=(model.connection.extra or {}).get("litellm_params"), + supports_image_input=bool((model.capabilities or {}).get("vision")), + billing_tier="free", + ) return ( - create_chat_litellm_from_agent_config(loaded_agent_config), - loaded_agent_config, + SanitizedChatLiteLLM(model=model_string, **litellm_kwargs), + agent_config, None, ) - loaded_llm_config = load_global_llm_config_by_id(config_id) - if not loaded_llm_config: - return None, None, f"Failed to load LLM config with id {config_id}" - return ( - create_chat_litellm_from_config(loaded_llm_config), - AgentConfig.from_yaml_config(loaded_llm_config), + global_model = next((m for m in config.GLOBAL_MODELS if m.get("id") == config_id), None) + if not global_model or not (global_model.get("capabilities") or {}).get("chat"): + return None, None, f"Failed to load global chat model with id {config_id}" + global_connection = next( + ( + c + for c in config.GLOBAL_CONNECTIONS + if c.get("id") == global_model.get("connection_id") + ), + None, + ) + if not global_connection: + return None, None, f"Failed to load global connection for model {config_id}" + model_string, litellm_kwargs = to_litellm(global_connection, global_model["model_id"]) + agent_config = _agent_config_from_resolved( + config_id=config_id, + config_name=global_model.get("display_name") or global_model.get("model_id"), + provider=global_connection.get("litellm_provider") or "", + model_name=global_model["model_id"], + api_key=global_connection.get("api_key"), + api_base=global_connection.get("base_url"), + litellm_params=(global_connection.get("extra") or {}).get("litellm_params"), + supports_image_input=bool((global_model.get("capabilities") or {}).get("vision")), + billing_tier=str(global_model.get("billing_tier", "free")).lower(), + ) + return ( + SanitizedChatLiteLLM(model=model_string, **litellm_kwargs), + agent_config, None, ) diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index d1af29aeb..0af41a7ee 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -75,10 +75,10 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): config, "GLOBAL_LLM_CONFIGS", [ - {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"}, { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -117,7 +117,7 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): [ { "id": -2, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", @@ -125,7 +125,7 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): }, { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -164,7 +164,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): [ { "id": -1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5.1", "api_key": "k1", "billing_tier": "premium", @@ -173,7 +173,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): }, { "id": -2, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5.4", "api_key": "k2", "billing_tier": "premium", @@ -182,7 +182,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): }, { "id": -3, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "openai/gpt-5.4", "api_key": "k3", "billing_tier": "premium", @@ -222,7 +222,7 @@ async def test_next_turn_reuses_existing_pin(monkeypatch): [ { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -263,7 +263,7 @@ async def test_premium_eligible_auto_can_pin_premium(monkeypatch): [ { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -301,14 +301,14 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch): [ { "id": -2, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", }, { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -346,14 +346,14 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): [ { "id": -2, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", }, { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -391,14 +391,14 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): [ { "id": -2, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", }, { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -437,7 +437,7 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch): config, "GLOBAL_LLM_CONFIGS", [ - {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"}, ], ) @@ -462,7 +462,7 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): config, "GLOBAL_LLM_CONFIGS", [ - {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"}, ], ) @@ -504,7 +504,7 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch): [ { "id": -1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "venice/dead-model", "api_key": "k1", "billing_tier": "free", @@ -514,7 +514,7 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch): }, { "id": -2, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemini-flash", "api_key": "k1", "billing_tier": "free", @@ -556,7 +556,7 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch): [ { "id": -1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "api_key": "k-yaml", "billing_tier": "premium", @@ -566,7 +566,7 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch): }, { "id": -2, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "openai/gpt-5", "api_key": "k-or", "billing_tier": "premium", @@ -608,7 +608,7 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch [ { "id": -1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "api_key": "k-yaml", "billing_tier": "premium", @@ -618,7 +618,7 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch }, { "id": -2, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemini-flash:free", "api_key": "k-or", "billing_tier": "free", @@ -656,7 +656,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch): high_score_cfgs = [ { "id": -i, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": f"gpt-x-{i}", "api_key": "k", "billing_tier": "premium", @@ -668,7 +668,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch): ] low_score_trap = { "id": -99, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "tiny-legacy", "api_key": "k", "billing_tier": "premium", @@ -729,7 +729,7 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch): [ { "id": -1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "venice/dead-model", "api_key": "k", "billing_tier": "premium", @@ -739,7 +739,7 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch): }, { "id": -2, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "api_key": "k", "billing_tier": "premium", @@ -781,7 +781,7 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): [ { "id": -1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "api_key": "k", "billing_tier": "premium", @@ -791,7 +791,7 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): }, { "id": -2, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5-pro", "api_key": "k", "billing_tier": "premium", @@ -839,7 +839,7 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): [ { "id": -1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemma-4-26b-a4b-it:free", "api_key": "k", "billing_tier": "free", @@ -849,7 +849,7 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): }, { "id": -2, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemini-2.5-flash:free", "api_key": "k", "billing_tier": "free", @@ -892,7 +892,7 @@ async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch): [ { "id": -1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemma-4-26b-a4b-it:free", "api_key": "k", "billing_tier": "free", @@ -937,7 +937,7 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa [ { "id": -1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemma-4-26b-a4b-it:free", "api_key": "k", "billing_tier": "free", @@ -947,7 +947,7 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa }, { "id": -2, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemini-2.5-flash:free", "api_key": "k", "billing_tier": "free", diff --git a/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py index 0e19b80e4..e267d59ba 100644 --- a/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py +++ b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py @@ -74,7 +74,7 @@ def _thread(*, pinned: int | None = None): def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict: return { "id": id_, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": f"vision-{id_}", "api_key": "k", "billing_tier": tier, @@ -87,7 +87,7 @@ def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict: def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict: return { "id": id_, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": f"text-{id_}", "api_key": "k", "billing_tier": tier, @@ -261,7 +261,7 @@ async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch): session = _FakeSession(_thread()) cfg_unannotated_vision = { "id": -2, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-4o", # known vision model in LiteLLM map "api_key": "k", "billing_tier": "free", diff --git a/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py index c309ff881..efe906ac0 100644 --- a/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py +++ b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py @@ -25,10 +25,10 @@ def _fake_yaml_config( return { "id": id, "name": f"yaml-{id}", - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": model_name, "api_key": "sk-test", - "api_base": "", + "api_base": "https://api.openai.com/v1", "billing_tier": billing_tier, "rpm": 100, "tpm": 100_000, @@ -54,10 +54,10 @@ def _fake_openrouter_config( return { "id": id, "name": f"or-{id}", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": model_name, "api_key": "sk-or-test", - "api_base": "", + "api_base": "https://openrouter.ai/api/v1", "billing_tier": billing_tier, "rpm": 20 if billing_tier == "free" else 200, "tpm": 100_000 if billing_tier == "free" else 1_000_000, diff --git a/surfsense_backend/tests/unit/services/test_or_health_enrichment.py b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py index 1c74aa928..b4b6618a4 100644 --- a/surfsense_backend/tests/unit/services/test_or_health_enrichment.py +++ b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py @@ -25,7 +25,7 @@ def _or_cfg( ) -> dict: return { "id": cid, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": model_name, "billing_tier": tier, "auto_pin_tier": "B" if tier == "premium" else "C", @@ -144,7 +144,7 @@ async def test_enrich_health_only_touches_or_provider(monkeypatch): """YAML cfgs that aren't OPENROUTER must be skipped entirely.""" yaml_cfg = { "id": -1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "billing_tier": "premium", "auto_pin_tier": "A", @@ -313,7 +313,7 @@ async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch): """When the catalogue has no OR cfgs at all, no HTTP calls fire.""" yaml_cfg: dict[str, Any] = { "id": -1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "billing_tier": "premium", } diff --git a/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py index 792d059b0..6bfc72bf3 100644 --- a/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py +++ b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py @@ -35,7 +35,7 @@ def test_safety_net_does_not_fire_for_azure_gpt_4o(): it text-only.""" assert ( is_known_text_only_chat_model( - provider="AZURE_OPENAI", + litellm_provider="azure", model_name="my-azure-deployment", base_model="gpt-4o", ) @@ -49,7 +49,7 @@ def test_safety_net_does_not_fire_for_unknown_model(): LiteLLM doesn't know about must flow through to the provider.""" assert ( is_known_text_only_chat_model( - provider="CUSTOM", + litellm_provider="custom", custom_provider="brand_new_proxy", model_name="brand-new-model-x9", ) @@ -69,7 +69,7 @@ def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch): assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="gpt-4o", ) is False @@ -88,7 +88,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch): monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false) assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="text-only-stub", ) is True @@ -100,7 +100,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch): monkeypatch.setattr(pc.litellm, "get_model_info", _info_true) assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="vision-stub", ) is False @@ -112,7 +112,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch): monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing) assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="missing-key-stub", ) is False