feat(chat): route models by provider capabilities

This commit is contained in:
Anish Sarkar 2026-06-11 18:22:23 +05:30
parent 8f20a32571
commit c28c4f5785
18 changed files with 429 additions and 319 deletions

View file

@ -2,9 +2,9 @@
LLM configuration utilities for SurfSense agents.
This module provides functions for loading LLM configurations from:
1. Auto mode (ID 0) - Uses LiteLLM Router for load balancing
1. Auto mode (ID 0) - Resolved by callers to a concrete model-connection model
2. YAML files (global configs with negative IDs)
3. Database NewLLMConfig table (user-created configs with positive IDs)
3. Database model-connections table (user-created configs with positive IDs)
It also provides utilities for creating ChatLiteLLM instances and
managing prompt configurations.
@ -33,9 +33,7 @@ from app.agents.chat.runtime.prompt_caching import (
from app.services.llm_router_service import (
AUTO_MODE_ID,
ChatLiteLLMRouter,
LLMRouterService,
_sanitize_content,
get_auto_mode_llm,
is_auto_mode,
)
@ -92,14 +90,6 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
yield chunk
# Re-exported under the historical name ``PROVIDER_MAP``. Source of truth lives
# in provider_capabilities so the YAML loader can resolve prefixes during
# app.config init without importing the agent/tools tree.
from app.services.provider_capabilities import ( # noqa: E402
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
)
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
"""Attach a ``profile`` dict to ChatLiteLLM with model context metadata."""
try:
@ -122,7 +112,8 @@ class AgentConfig:
Complete configuration for the SurfSense agent.
This combines LLM settings with prompt configuration from NewLLMConfig.
Supports Auto mode (ID 0) which uses LiteLLM Router for load balancing.
Supports Auto mode metadata (ID 0). Runtime callers must resolve Auto to
a concrete global or BYOK model before constructing ChatLiteLLM.
"""
# LLM Model Settings
@ -219,7 +210,7 @@ class AgentConfig:
# BYOK rows have no curated flag; ask LiteLLM (default-allow on
# unknown). The streaming safety net still blocks explicit text-only.
supports_image_input=derive_supports_image_input(
provider=provider_value,
litellm_provider=provider_value.lower(),
model_name=config.model_name,
base_model=base_model,
custom_provider=config.custom_provider,
@ -238,7 +229,7 @@ class AgentConfig:
system_instructions = yaml_config.get("system_instructions", "")
provider = yaml_config.get("provider", "").upper()
provider = yaml_config.get("litellm_provider", "")
model_name = yaml_config.get("model_name", "")
custom_provider = yaml_config.get("custom_provider")
litellm_params = yaml_config.get("litellm_params") or {}
@ -254,7 +245,7 @@ class AgentConfig:
supports_image_input = bool(yaml_config.get("supports_image_input"))
else:
supports_image_input = derive_supports_image_input(
provider=provider,
litellm_provider=provider,
model_name=model_name,
base_model=base_model,
custom_provider=custom_provider,
@ -383,9 +374,6 @@ async def load_agent_config(
) -> "AgentConfig | None":
"""Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB."""
if is_auto_mode(config_id):
if not LLMRouterService.is_initialized():
print("Error: Auto mode requested but LLM Router not initialized")
return None
return AgentConfig.from_auto_mode()
if config_id < 0:
@ -408,9 +396,8 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
if llm_config.get("custom_provider"):
model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}"
else:
provider = llm_config.get("provider", "").upper()
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{llm_config['model_name']}"
litellm_provider = llm_config.get("litellm_provider", "openai")
model_string = f"{litellm_provider}/{llm_config['model_name']}"
litellm_kwargs = {
"model": model_string,
@ -433,29 +420,15 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
def create_chat_litellm_from_agent_config(
agent_config: AgentConfig,
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
"""Create a ChatLiteLLM (or, for Auto mode, a load-balancing router) from config."""
"""Create a ChatLiteLLM from an already resolved concrete model config."""
if agent_config.is_auto_mode:
if not LLMRouterService.is_initialized():
print("Error: Auto mode requested but LLM Router not initialized")
return None
try:
router_llm = get_auto_mode_llm()
if router_llm is not None:
# Universal injection points only: auto-mode fans out across
# providers, so provider-specific kwargs have no known target.
apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
return router_llm
except Exception as e:
print(f"Error creating ChatLiteLLMRouter: {e}")
return None
print("Error: Auto mode must be resolved to a concrete model before LLM creation")
return None
if agent_config.custom_provider:
model_string = f"{agent_config.custom_provider}/{agent_config.model_name}"
else:
provider_prefix = PROVIDER_MAP.get(
agent_config.provider, agent_config.provider.lower()
)
model_string = f"{provider_prefix}/{agent_config.model_name}"
model_string = f"{agent_config.provider}/{agent_config.model_name}"
litellm_kwargs = {
"model": model_string,

View file

@ -132,7 +132,7 @@ async def list_anonymous_models():
id=cfg.get("id", 0),
name=cfg.get("name", ""),
description=cfg.get("description"),
provider=cfg.get("provider", ""),
provider=cfg.get("litellm_provider", ""),
model_name=cfg.get("model_name", ""),
billing_tier=cfg.get("billing_tier", "free"),
is_premium=cfg.get("billing_tier", "free") == "premium",
@ -161,7 +161,7 @@ async def get_anonymous_model(slug: str):
id=cfg.get("id", 0),
name=cfg.get("name", ""),
description=cfg.get("description"),
provider=cfg.get("provider", ""),
provider=cfg.get("litellm_provider", ""),
model_name=cfg.get("model_name", ""),
billing_tier=cfg.get("billing_tier", "free"),
is_premium=cfg.get("billing_tier", "free") == "premium",

View file

@ -96,7 +96,7 @@ async def get_global_vision_llm_configs(
"id": cfg.get("id"),
"name": cfg.get("name"),
"description": cfg.get("description"),
"provider": cfg.get("provider"),
"provider": cfg.get("litellm_provider"),
"custom_provider": cfg.get("custom_provider"),
"model_name": cfg.get("model_name"),
"api_base": cfg.get("api_base") or None,

View file

@ -23,9 +23,10 @@ from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.config import config
from app.db import NewChatThread
from app.db import Connection, Model, NewChatThread
from app.services.quality_score import _QUALITY_TOP_K
from app.services.token_quota_service import TokenQuotaService
@ -61,11 +62,20 @@ def _is_usable_global_config(cfg: dict) -> bool:
return bool(
cfg.get("id") is not None
and cfg.get("model_name")
and cfg.get("provider")
and cfg.get("litellm_provider")
and cfg.get("api_key")
)
def _has_capability(model: dict | Model, capability: str) -> bool:
    """Return True when ``model`` carries a truthy flag for ``capability``.

    Args:
        model: Either a plain dict (read via its ``"capabilities"`` key) or an
            object exposing a ``capabilities`` attribute.
        capability: Key to look up in the capabilities mapping.

    Returns:
        ``bool`` of the stored flag; ``False`` when the capabilities mapping is
        missing, ``None``, empty, or lacks the key.
    """
    raw_caps = (
        model.get("capabilities")
        if isinstance(model, dict)
        else model.capabilities
    )
    # Normalise *after* choosing the branch so that a dict carrying an explicit
    # ``"capabilities": None`` is treated like an absent mapping instead of
    # raising AttributeError on ``None.get``.
    caps = raw_caps or {}
    return bool(caps.get(capability))
def _prune_runtime_cooldowns(now_ts: float | None = None) -> None:
now = time.time() if now_ts is None else now_ts
stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now]
@ -186,15 +196,19 @@ def _cfg_supports_image_input(cfg: dict) -> bool:
else None
)
return derive_supports_image_input(
provider=cfg.get("provider"),
litellm_provider=cfg.get("litellm_provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
)
def _global_candidates(*, requires_image_input: bool = False) -> list[dict]:
"""Return Auto-eligible global cfgs.
def _global_candidates(
*,
capability: str = "chat",
requires_image_input: bool = False,
) -> list[dict]:
"""Return Auto-eligible global virtual models.
Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
@ -205,17 +219,135 @@ def _global_candidates(*, requires_image_input: bool = False) -> list[dict]:
filters out configs whose ``supports_image_input`` resolves to False
so a text-only deployment can't be pinned for an image request.
"""
candidates = [
cfg
connection_by_id = {
int(conn.get("id")): conn
for conn in config.GLOBAL_CONNECTIONS
if conn.get("id") is not None
}
config_by_model_name = {
cfg.get("model_name"): cfg
for cfg in config.GLOBAL_LLM_CONFIGS
if _is_usable_global_config(cfg)
and not cfg.get("health_gated")
and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
and (not requires_image_input or _cfg_supports_image_input(cfg))
]
}
candidates: list[dict] = []
for model in config.GLOBAL_MODELS:
model_id = int(model.get("id", 0))
if model_id >= 0 or _is_runtime_cooled_down(model_id):
continue
if not _has_capability(model, capability):
continue
cfg = config_by_model_name.get(model.get("model_id")) or {}
if cfg.get("health_gated"):
continue
if requires_image_input and not _has_capability(model, "vision"):
continue
if requires_image_input and cfg and not _cfg_supports_image_input(cfg):
continue
connection = connection_by_id.get(int(model.get("connection_id", 0)))
if not connection:
continue
catalog = model.get("catalog") or {}
candidates.append(
{
"id": model_id,
"model_id": model.get("model_id"),
"source": "global",
"connection": connection,
"capabilities": model.get("capabilities") or {},
"billing_tier": model.get("billing_tier", "free"),
"litellm_provider": connection.get("litellm_provider"),
"model_name": model.get("model_id"),
"auto_pin_tier": catalog.get("auto_pin_tier")
or cfg.get("auto_pin_tier")
or "A",
"quality_score": catalog.get("quality_score")
or cfg.get("quality_score")
or cfg.get("quality_score_static")
or 50,
}
)
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
async def _db_candidates(
    session: AsyncSession,
    *,
    search_space_id: int,
    user_id: str | UUID | None,
    capability: str,
    requires_image_input: bool = False,
) -> list[dict]:
    """Return Auto-eligible candidates built from enabled DB ``Model`` rows.

    A row is eligible when both the model and its connection are enabled, the
    connection is either unscoped (``search_space_id``/``user_id`` is NULL) or
    scoped to the given search space / user, the model advertises
    ``capability`` (and ``"vision"`` when ``requires_image_input`` is set), and
    the model id is not currently runtime-cooled-down.

    Args:
        session: Async SQLAlchemy session used to run the query.
        search_space_id: Search space whose scoped connections are visible.
        user_id: Owner whose private connections are visible; when it resolves
            to ``None`` only connections without an owner are returned.
        capability: Capability flag the model must advertise.
        requires_image_input: Additionally require the ``"vision"`` flag.

    Returns:
        Candidate dicts (``source="db"``, ``billing_tier="byok"``) sorted by
        ascending model id.
    """
    # Function-scope import: ``or_`` is only needed to build this query.
    from sqlalchemy import or_

    parsed_user_id = _to_uuid(user_id)

    # Scope the query to the caller in SQL so other tenants' models and
    # connections are never loaded into memory just to be discarded.
    if parsed_user_id is None:
        owner_filter = Connection.user_id.is_(None)
    else:
        owner_filter = or_(
            Connection.user_id.is_(None),
            Connection.user_id == parsed_user_id,
        )
    stmt = (
        select(Model)
        .options(selectinload(Model.connection))
        .join(Connection, Model.connection_id == Connection.id)
        .where(
            Model.enabled.is_(True),
            Connection.enabled.is_(True),
            or_(
                Connection.search_space_id.is_(None),
                Connection.search_space_id == search_space_id,
            ),
            owner_filter,
        )
    )
    result = await session.execute(stmt)

    candidates: list[dict] = []
    for model in result.scalars().all():
        conn = model.connection
        if not conn:
            continue
        if not _has_capability(model, capability):
            continue
        if requires_image_input and not _has_capability(model, "vision"):
            continue
        model_id = int(model.id)
        if _is_runtime_cooled_down(model_id):
            continue
        catalog = model.catalog or {}
        candidates.append(
            {
                "id": model_id,
                "model_id": model.model_id,
                "source": "db",
                "connection": conn,
                "capabilities": model.capabilities or {},
                "billing_tier": "byok",
                "litellm_provider": conn.litellm_provider,
                "model_name": model.model_id,
                # Fallbacks apply when the catalog carries no (truthy) value.
                "auto_pin_tier": catalog.get("auto_pin_tier") or "BYOK",
                "quality_score": catalog.get("quality_score") or 75,
            }
        )
    return sorted(candidates, key=lambda c: int(c.get("id", 0)))
async def auto_model_candidates(
    session: AsyncSession,
    *,
    search_space_id: int,
    user_id: str | UUID | None,
    capability: str,
    requires_image_input: bool = False,
    exclude_model_ids: set[int] | None = None,
) -> list[dict]:
    """Collect Auto-mode candidates from global and DB sources.

    Global candidates come first, followed by DB candidates; any candidate
    whose integer ``id`` appears in ``exclude_model_ids`` is dropped.
    """
    skip_ids = {int(mid) for mid in (exclude_model_ids or set())}
    byok_pool = await _db_candidates(
        session,
        search_space_id=search_space_id,
        user_id=user_id,
        capability=capability,
        requires_image_input=requires_image_input,
    )
    global_pool = _global_candidates(
        capability=capability,
        requires_image_input=requires_image_input,
    )
    kept: list[dict] = []
    for candidate in (*global_pool, *byok_pool):
        if int(candidate.get("id", 0)) in skip_ids:
            continue
        kept.append(candidate)
    return kept
def _tier_of(cfg: dict) -> str:
    """Return the config's ``billing_tier`` as a lower-cased string.

    A missing key yields ``"free"``; any present value is passed through
    ``str()`` before lower-casing.
    """
    raw_tier = cfg.get("billing_tier", "free")
    return str(raw_tier).lower()
@ -223,8 +355,9 @@ def _tier_of(cfg: dict) -> str:
def _is_preferred_premium_auto_config(cfg: dict) -> bool:
"""Return True for the operator-preferred premium Auto model."""
return (
_tier_of(cfg) == "premium"
and str(cfg.get("provider", "")).upper() == "AZURE_OPENAI"
cfg.get("source") == "global"
and _tier_of(cfg) == "premium"
and str(cfg.get("litellm_provider", "")).lower() == "azure"
and str(cfg.get("model_name", "")).lower() == "gpt-5.4"
)
@ -251,6 +384,11 @@ def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
return top_k[idx], len(top_k)
def choose_auto_model_candidate(candidates: list[dict], seed_id: int) -> dict:
    """Pick one candidate for ``seed_id`` via the shared ``_select_pin`` helper.

    Only the selected candidate is returned; the second element of the
    helper's result tuple is discarded.
    """
    return _select_pin(candidates, seed_id)[0]
def _to_uuid(user_id: str | UUID | None) -> UUID | None:
if user_id is None:
return None
@ -326,20 +464,23 @@ async def resolve_or_get_pinned_llm_config_id(
)
excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
candidates = [
c
for c in _global_candidates(requires_image_input=requires_image_input)
if int(c.get("id", 0)) not in excluded_ids
]
candidates = await auto_model_candidates(
session,
search_space_id=search_space_id,
user_id=user_id,
capability="chat",
requires_image_input=requires_image_input,
exclude_model_ids=excluded_ids,
)
if not candidates:
if requires_image_input:
# Distinguish the "no vision-capable cfg" case from generic
# "no usable cfg" so the streaming task can map this to the
# MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error.
raise ValueError(
"No vision-capable global LLM configs are available for Auto mode"
"No vision-capable LLM models are available for Auto mode"
)
raise ValueError("No usable global LLM configs are available for Auto mode")
raise ValueError("No usable LLM models are available for Auto mode")
candidate_by_id = {int(c["id"]): c for c in candidates}
# Reuse an existing valid pin without re-checking current quota (no silent
@ -379,24 +520,13 @@ async def resolve_or_get_pinned_llm_config_id(
# log that explicitly so operators can correlate the re-pin with
# the user's image attachment instead of suspecting a cooldown.
if requires_image_input:
try:
pinned_global = next(
c
for c in config.GLOBAL_LLM_CONFIGS
if int(c.get("id", 0)) == int(pinned_id)
)
except StopIteration:
pinned_global = None
if pinned_global is not None and not _cfg_supports_image_input(
pinned_global
):
logger.info(
"auto_pin_repinned_for_image thread_id=%s search_space_id=%s "
"previous_config_id=%s",
thread_id,
search_space_id,
pinned_id,
)
logger.info(
"auto_pin_repinned_for_image thread_id=%s search_space_id=%s "
"previous_config_id=%s",
thread_id,
search_space_id,
pinned_id,
)
logger.info(
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
thread_id,

View file

@ -30,11 +30,7 @@ from litellm.exceptions import (
)
from pydantic import Field
from app.services.model_resolver import (
NATIVE_PROVIDER_PREFIX,
native_connection_from_config,
to_litellm,
)
from app.services.model_resolver import native_connection_from_config, to_litellm
from app.utils.perf import get_perf_logger
litellm.json_logs = False
@ -101,10 +97,6 @@ def _sanitize_content(content: Any) -> Any:
# Special ID for Auto mode - uses router for load balancing
AUTO_MODE_ID = 0
# Historical export kept for callers that still import PROVIDER_MAP.
PROVIDER_MAP = NATIVE_PROVIDER_PREFIX
class LLMRouterService:
"""
Singleton service for managing LiteLLM Router.

View file

@ -10,13 +10,11 @@ from sqlalchemy.orm import selectinload
from app.config import config
from app.db import Model, SearchSpace
from app.services.llm_router_service import (
AUTO_MODE_ID,
ChatLiteLLMRouter,
LLMRouterService,
get_auto_mode_llm,
is_auto_mode,
from app.services.auto_model_pin_service import (
auto_model_candidates,
choose_auto_model_candidate,
)
from app.services.llm_router_service import AUTO_MODE_ID, ChatLiteLLMRouter, is_auto_mode
from app.services.model_resolver import native_connection_from_config, to_litellm
from app.services.token_tracking_service import token_tracker
@ -78,7 +76,7 @@ def _legacy_config_connection(
api_version: str | None = None,
) -> tuple[str, dict]:
cfg = {
"provider": provider,
"litellm_provider": provider.lower(),
"model_name": model_name,
"api_key": api_key,
"api_base": api_base,
@ -325,23 +323,21 @@ async def get_search_space_llm_instance(
logger.error(f"No {role} LLM configured for search space {search_space_id}")
return None
# Check for Auto mode (ID 0) - use router for load balancing
# Auto mode resolves to one concrete global or BYOK model from the
# unified model-connections catalog.
if is_auto_mode(llm_config_id):
if not LLMRouterService.is_initialized():
logger.error(
"Auto mode requested but LLM Router not initialized. "
"Ensure global_llm_config.yaml exists with valid configs."
)
return None
try:
logger.debug(
f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}"
)
return get_auto_mode_llm(streaming=not disable_streaming)
except Exception as e:
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
candidates = await auto_model_candidates(
session,
search_space_id=search_space_id,
user_id=search_space.user_id,
capability="chat",
)
if not candidates:
logger.error("No chat-capable models available for Auto mode")
return None
llm_config_id = int(
choose_auto_model_candidate(candidates, search_space_id)["id"]
)
# Check if this is a global virtual model (negative ID)
if llm_config_id < 0:
@ -414,7 +410,7 @@ async def get_vision_llm(
"""Get the search space's vision LLM instance for screenshot analysis.
Resolves from the new connection/model role bindings:
- Auto mode (ID 0): VisionLLMRouterService
- Auto mode (ID 0): unified global/BYOK model candidate selection
- Global (negative ID): virtual GLOBAL models from YAML
- DB (positive ID): Model + Connection tables
@ -424,10 +420,7 @@ async def get_vision_llm(
unwrapped they don't consume premium credit (issue M).
"""
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
from app.services.vision_llm_router_service import (
VisionLLMRouterService,
is_vision_auto_mode,
)
from app.services.vision_llm_router_service import is_vision_auto_mode
try:
result = await session.execute(
@ -476,26 +469,16 @@ async def get_vision_llm(
return None
if is_vision_auto_mode(config_id):
if not VisionLLMRouterService.is_initialized():
logger.error(
"Vision Auto mode requested but Vision LLM Router not initialized"
)
return None
try:
# Auto mode is currently treated as free at the wrapper
# level — the underlying router can dispatch to either
# premium or free YAML configs but routing decisions are
# opaque. If/when we want to bill Auto-routed vision
# calls we'd need to thread the resolved deployment's
# billing_tier back from the router. For now we keep
# parity with chat Auto, which also doesn't pre-classify.
return ChatLiteLLMRouter(
router=VisionLLMRouterService.get_router(),
streaming=True,
)
except Exception as e:
logger.error(f"Failed to create vision ChatLiteLLMRouter: {e}")
candidates = await auto_model_candidates(
session,
search_space_id=search_space_id,
user_id=owner_user_id,
capability="vision",
)
if not candidates:
logger.error("No vision-capable models available for Auto mode")
return None
config_id = int(choose_auto_model_candidate(candidates, search_space_id)["id"])
if config_id < 0:
global_model = get_global_model(config_id)

View file

@ -154,19 +154,19 @@ def _process_models(raw_models: list[dict]) -> list[dict]:
}
)
# 2) Emit for the native provider when we have a mapping
native_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug)
if native_provider:
# 2) Emit for the direct provider when we have a mapping
direct_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug)
if direct_provider:
# Google's Gemini API only serves gemini-* models.
# Open-source models like gemma-* are NOT available through it.
if native_provider == "GOOGLE" and not model_name.startswith("gemini-"):
if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"):
continue
processed.append(
{
"value": model_name,
"label": name,
"provider": native_provider,
"provider": direct_provider,
"context_window": context_window,
}
)

View file

@ -9,53 +9,12 @@ from __future__ import annotations
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from app.services.provider_api_base import resolve_api_base
if TYPE_CHECKING:
from app.db import Connection
PROTOCOL_OLLAMA = "OLLAMA"
PROTOCOL_OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE"
PROTOCOL_NATIVE = "NATIVE"
NATIVE_PROVIDER_PREFIX: dict[str, str] = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"AZURE": "azure",
"OPENROUTER": "openrouter",
"COMETAPI": "cometapi",
"XAI": "xai",
"BEDROCK": "bedrock",
"AWS_BEDROCK": "bedrock",
"VERTEX_AI": "vertex_ai",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
"GITHUB_MODELS": "github",
"REPLICATE": "replicate",
"PERPLEXITY": "perplexity",
"ANYSCALE": "anyscale",
"DEEPINFRA": "deepinfra",
"CEREBRAS": "cerebras",
"SAMBANOVA": "sambanova",
"AI21": "ai21",
"CLOUDFLARE": "cloudflare",
"DATABRICKS": "databricks",
"HUGGINGFACE": "huggingface",
"MINIMAX": "openai",
"RECRAFT": "recraft",
"XINFERENCE": "xinference",
"NSCALE": "nscale",
"CUSTOM": "custom",
}
PROTOCOL_ANTHROPIC = "ANTHROPIC"
def ensure_v1(base_url: str | None) -> str | None:
@ -77,6 +36,23 @@ def _protocol_value(protocol: Any) -> str:
return getattr(protocol, "value", str(protocol))
def default_litellm_provider(protocol: Any) -> str:
    """Map a connection protocol to its default LiteLLM provider prefix.

    Ollama maps to ``"ollama_chat"`` and Anthropic to ``"anthropic"``; the
    OpenAI-compatible protocol and every other value map to ``"openai"``.
    """
    resolved = _protocol_value(protocol)
    if resolved == PROTOCOL_OLLAMA:
        return "ollama_chat"
    if resolved == PROTOCOL_ANTHROPIC:
        return "anthropic"
    # PROTOCOL_OPENAI_COMPATIBLE and unrecognised protocols share the default.
    return "openai"
def _execution_api_base(protocol: str, base_url: str | None) -> str | None:
    """Return ``base_url`` without trailing slashes, or ``None`` when it is empty.

    ``protocol`` is accepted but deliberately unused.
    """
    del protocol  # part of the signature only; no protocol-specific handling
    return base_url.rstrip("/") if base_url else None
def to_litellm(
conn: Connection | Mapping[str, Any],
model_id: str,
@ -85,38 +61,19 @@ def to_litellm(
protocol = _protocol_value(_conn_value(conn, "protocol"))
base_url = _conn_value(conn, "base_url")
api_key = _conn_value(conn, "api_key")
native_provider = _conn_value(conn, "native_provider")
litellm_provider = (
_conn_value(conn, "litellm_provider") or default_litellm_provider(protocol)
)
extra = _conn_value(conn, "extra") or {}
kwargs: dict[str, Any] = {}
if api_key:
kwargs["api_key"] = api_key
if protocol == PROTOCOL_OLLAMA:
model_string = f"ollama_chat/{model_id}"
if base_url:
kwargs["api_base"] = base_url.rstrip("/")
elif protocol == PROTOCOL_OPENAI_COMPATIBLE:
model_string = f"openai/{model_id}"
api_base = ensure_v1(base_url)
if api_base:
kwargs["api_base"] = api_base
else:
provider_key = (native_provider or "").upper()
prefix = NATIVE_PROVIDER_PREFIX.get(provider_key, provider_key.lower())
if prefix == "custom":
custom_provider = extra.get("custom_provider") or native_provider
model_string = f"{custom_provider}/{model_id}" if custom_provider else model_id
else:
model_string = f"{prefix}/{model_id}"
api_base = resolve_api_base(
provider=provider_key,
provider_prefix=prefix,
config_api_base=base_url,
)
if api_base:
kwargs["api_base"] = api_base
model_string = f"{litellm_provider}/{model_id}" if litellm_provider else model_id
api_base = _execution_api_base(protocol, base_url)
if api_base:
kwargs["api_base"] = api_base
if api_version := extra.get("api_version"):
kwargs["api_version"] = api_version
@ -126,18 +83,21 @@ def to_litellm(
def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]:
"""Build an in-memory NATIVE connection mapping from a legacy/global config."""
provider = str(config.get("provider") or config.get("custom_provider") or "CUSTOM")
"""Build an in-memory connection mapping from a global config."""
protocol = str(config.get("protocol") or PROTOCOL_OPENAI_COMPATIBLE)
litellm_provider = str(
config.get("litellm_provider")
or config.get("custom_provider")
or default_litellm_provider(protocol)
)
extra: dict[str, Any] = {
"litellm_params": config.get("litellm_params") or {},
}
if config.get("api_version"):
extra["api_version"] = config.get("api_version")
if config.get("custom_provider"):
extra["custom_provider"] = config.get("custom_provider")
return {
"protocol": PROTOCOL_NATIVE,
"native_provider": provider,
"protocol": protocol,
"litellm_provider": litellm_provider,
"base_url": config.get("api_base") or None,
"api_key": config.get("api_key") or None,
"extra": extra,
@ -145,7 +105,7 @@ def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]:
__all__ = [
"NATIVE_PROVIDER_PREFIX",
"default_litellm_provider",
"ensure_v1",
"native_connection_from_config",
"to_litellm",

View file

@ -3,19 +3,12 @@ from typing import Any
from litellm import Router
from app.services.model_resolver import (
NATIVE_PROVIDER_PREFIX,
native_connection_from_config,
to_litellm,
)
from app.services.model_resolver import native_connection_from_config, to_litellm
logger = logging.getLogger(__name__)
VISION_AUTO_MODE_ID = 0
VISION_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX
class VisionLLMRouterService:
_instance = None
_router: Router | None = None
@ -141,12 +134,11 @@ def is_vision_auto_mode(config_id: int | None) -> bool:
def build_vision_model_string(
provider: str, model_name: str, custom_provider: str | None
litellm_provider: str, model_name: str, custom_provider: str | None
) -> str:
if custom_provider:
return f"{custom_provider}/{model_name}"
prefix = VISION_PROVIDER_MAP.get(provider.upper(), provider.lower())
return f"{prefix}/{model_name}"
return f"{litellm_provider}/{model_name}"
def get_global_vision_llm_config(config_id: int) -> dict | None:

View file

@ -97,16 +97,16 @@ def _process_vision_models(raw_models: list[dict]) -> list[dict]:
}
)
native_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug)
if native_provider:
if native_provider == "GOOGLE" and not model_name.startswith("gemini-"):
direct_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug)
if direct_provider:
if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"):
continue
processed.append(
{
"value": model_name,
"label": name,
"provider": native_provider,
"provider": direct_provider,
"context_window": context_window,
}
)

View file

@ -40,7 +40,7 @@ def check_image_input_capability(
else None
)
if not is_known_text_only_chat_model(
provider=agent_config.provider,
litellm_provider=agent_config.provider,
model_name=agent_config.model_name,
base_model=agent_base_model,
custom_provider=agent_config.custom_provider,

View file

@ -80,7 +80,6 @@ async def _generate_title(
from litellm import acompletion
from app.services.llm_router_service import LLMRouterService
from app.services.provider_api_base import resolve_api_base
from app.services.token_tracking_service import _turn_accumulator
# Excludes this turn's own assistant row (pre-written by
@ -125,26 +124,12 @@ async def _generate_title(
router = LLMRouterService.get_router()
response = await router.acompletion(model="auto", messages=messages)
else:
# Apply the same ``api_base`` cascade chat / vision / image-gen
# call sites use so we never inherit ``litellm.api_base``
# (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat
# config itself ships an empty ``api_base``. Without this the
# title-gen on an OpenRouter chat config would 404 against the
# inherited Azure endpoint — see ``provider_api_base`` for the
# same bug repro on the image-gen / vision paths.
raw_model = getattr(llm, "model", "") or ""
provider_prefix = raw_model.split("/", 1)[0] if "/" in raw_model else None
provider_value = agent_config.provider if agent_config is not None else None
title_api_base = resolve_api_base(
provider=provider_value,
provider_prefix=provider_prefix,
config_api_base=getattr(llm, "api_base", None),
)
response = await acompletion(
model=raw_model,
messages=messages,
api_key=getattr(llm, "api_key", None),
api_base=title_api_base,
api_base=getattr(llm, "api_base", None),
)
usage_info = None

View file

@ -1,8 +1,8 @@
"""Load an LLM + AgentConfig bundle for a given config id.
Handles both code paths uniformly:
- ``config_id >= 0`` database-backed ``NewLLMConfig`` row (per-user/per-space).
- ``config_id < 0`` YAML-defined global LLM config (built-in defaults).
- ``config_id > 0`` database-backed model-connection ``Model`` row.
- ``config_id < 0`` virtual global model materialized from YAML/OpenRouter.
Returns ``(llm, agent_config, error_message)``; on success ``error_message`` is
``None``. The caller emits the friendly SSE error frame.
@ -12,15 +12,72 @@ from __future__ import annotations
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.agents.chat.runtime.llm_config import (
AgentConfig,
create_chat_litellm_from_agent_config,
create_chat_litellm_from_config,
load_agent_config,
load_global_llm_config_by_id,
SanitizedChatLiteLLM,
)
from app.config import config
from app.db import Model, SearchSpace
from app.services.model_resolver import to_litellm
def _agent_config_from_resolved(
    *,
    config_id: int,
    config_name: str | None,
    provider: str,
    model_name: str,
    api_key: str | None,
    api_base: str | None,
    litellm_params: dict | None,
    supports_image_input: bool,
    billing_tier: str = "free",
) -> AgentConfig:
    """Build a non-Auto ``AgentConfig`` from already resolved model settings.

    A missing ``api_key`` is stored as an empty string, ``custom_provider`` is
    always ``None``, and ``is_premium`` is set exactly when ``billing_tier``
    equals ``"premium"``.
    """
    resolved_key = api_key or ""
    premium = billing_tier == "premium"
    return AgentConfig(
        # Identity of the resolved config
        config_id=config_id,
        config_name=config_name,
        is_auto_mode=False,
        # Model / connection settings
        provider=provider,
        model_name=model_name,
        api_key=resolved_key,
        api_base=api_base,
        custom_provider=None,
        litellm_params=litellm_params,
        # Billing and capability flags
        billing_tier=billing_tier,
        is_premium=premium,
        supports_image_input=supports_image_input,
    )
async def _load_search_space(
    session: AsyncSession, search_space_id: int
) -> SearchSpace | None:
    """Fetch the ``SearchSpace`` row with the given id, or ``None`` if absent."""
    stmt = select(SearchSpace).where(SearchSpace.id == search_space_id)
    rows = await session.execute(stmt)
    return rows.scalars().first()
async def _load_db_model(
    session: AsyncSession,
    *,
    model_id: int,
    search_space: SearchSpace,
) -> Model | None:
    """Load an enabled DB model that ``search_space`` is allowed to use.

    Returns ``None`` when the model is missing or disabled, when its
    connection is missing or disabled, or when the connection is scoped to a
    different search space or to a user other than the search space's owner.
    """
    stmt = (
        select(Model)
        .options(selectinload(Model.connection))
        .where(Model.id == model_id, Model.enabled.is_(True))
    )
    model = (await session.execute(stmt)).scalars().first()
    if not model:
        return None

    conn = model.connection
    if not conn or not conn.enabled:
        return None

    # A NULL scope column means the connection is not restricted on that axis.
    scoped_space_id = conn.search_space_id
    if scoped_space_id is not None and scoped_space_id != search_space.id:
        return None
    scoped_owner_id = conn.user_id
    if scoped_owner_id is not None and scoped_owner_id != search_space.user_id:
        return None
    return model
async def load_llm_bundle(
@ -29,29 +86,67 @@ async def load_llm_bundle(
config_id: int,
search_space_id: int,
) -> tuple[Any, AgentConfig | None, str | None]:
if config_id >= 0:
loaded_agent_config = await load_agent_config(
session=session,
config_id=config_id,
search_space_id=search_space_id,
search_space = await _load_search_space(session, search_space_id)
if not search_space:
return None, None, f"Search space {search_space_id} not found"
if config_id > 0:
model = await _load_db_model(
session,
model_id=config_id,
search_space=search_space,
)
if not loaded_agent_config:
if not model or not (model.capabilities or {}).get("chat"):
return (
None,
None,
f"Failed to load NewLLMConfig with id {config_id}",
f"Failed to load chat model with id {config_id}",
)
model_string, litellm_kwargs = to_litellm(model.connection, model.model_id)
agent_config = _agent_config_from_resolved(
config_id=config_id,
config_name=model.display_name or model.model_id,
provider=model.connection.litellm_provider or "",
model_name=model.model_id,
api_key=model.connection.api_key,
api_base=model.connection.base_url,
litellm_params=(model.connection.extra or {}).get("litellm_params"),
supports_image_input=bool((model.capabilities or {}).get("vision")),
billing_tier="free",
)
return (
create_chat_litellm_from_agent_config(loaded_agent_config),
loaded_agent_config,
SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
agent_config,
None,
)
loaded_llm_config = load_global_llm_config_by_id(config_id)
if not loaded_llm_config:
return None, None, f"Failed to load LLM config with id {config_id}"
return (
create_chat_litellm_from_config(loaded_llm_config),
AgentConfig.from_yaml_config(loaded_llm_config),
global_model = next((m for m in config.GLOBAL_MODELS if m.get("id") == config_id), None)
if not global_model or not (global_model.get("capabilities") or {}).get("chat"):
return None, None, f"Failed to load global chat model with id {config_id}"
global_connection = next(
(
c
for c in config.GLOBAL_CONNECTIONS
if c.get("id") == global_model.get("connection_id")
),
None,
)
if not global_connection:
return None, None, f"Failed to load global connection for model {config_id}"
model_string, litellm_kwargs = to_litellm(global_connection, global_model["model_id"])
agent_config = _agent_config_from_resolved(
config_id=config_id,
config_name=global_model.get("display_name") or global_model.get("model_id"),
provider=global_connection.get("litellm_provider") or "",
model_name=global_model["model_id"],
api_key=global_connection.get("api_key"),
api_base=global_connection.get("base_url"),
litellm_params=(global_connection.get("extra") or {}).get("litellm_params"),
supports_image_input=bool((global_model.get("capabilities") or {}).get("vision")),
billing_tier=str(global_model.get("billing_tier", "free")).lower(),
)
return (
SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
agent_config,
None,
)

View file

@ -75,10 +75,10 @@ async def test_auto_first_turn_pins_one_model(monkeypatch):
config,
"GLOBAL_LLM_CONFIGS",
[
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
{"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"},
{
"id": -1,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-prem",
"api_key": "k2",
"billing_tier": "premium",
@ -117,7 +117,7 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
[
{
"id": -2,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-free",
"api_key": "k1",
"billing_tier": "free",
@ -125,7 +125,7 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
},
{
"id": -1,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-prem",
"api_key": "k2",
"billing_tier": "premium",
@ -164,7 +164,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
[
{
"id": -1,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5.1",
"api_key": "k1",
"billing_tier": "premium",
@ -173,7 +173,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
},
{
"id": -2,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5.4",
"api_key": "k2",
"billing_tier": "premium",
@ -182,7 +182,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
},
{
"id": -3,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "openai/gpt-5.4",
"api_key": "k3",
"billing_tier": "premium",
@ -222,7 +222,7 @@ async def test_next_turn_reuses_existing_pin(monkeypatch):
[
{
"id": -1,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-prem",
"api_key": "k2",
"billing_tier": "premium",
@ -263,7 +263,7 @@ async def test_premium_eligible_auto_can_pin_premium(monkeypatch):
[
{
"id": -1,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-prem",
"api_key": "k2",
"billing_tier": "premium",
@ -301,14 +301,14 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch):
[
{
"id": -2,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-free",
"api_key": "k1",
"billing_tier": "free",
},
{
"id": -1,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-prem",
"api_key": "k2",
"billing_tier": "premium",
@ -346,14 +346,14 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch):
[
{
"id": -2,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-free",
"api_key": "k1",
"billing_tier": "free",
},
{
"id": -1,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-prem",
"api_key": "k2",
"billing_tier": "premium",
@ -391,14 +391,14 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch):
[
{
"id": -2,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-free",
"api_key": "k1",
"billing_tier": "free",
},
{
"id": -1,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-prem",
"api_key": "k2",
"billing_tier": "premium",
@ -437,7 +437,7 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch):
config,
"GLOBAL_LLM_CONFIGS",
[
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
{"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"},
],
)
@ -462,7 +462,7 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
config,
"GLOBAL_LLM_CONFIGS",
[
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
{"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"},
],
)
@ -504,7 +504,7 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch):
[
{
"id": -1,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "venice/dead-model",
"api_key": "k1",
"billing_tier": "free",
@ -514,7 +514,7 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch):
},
{
"id": -2,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "google/gemini-flash",
"api_key": "k1",
"billing_tier": "free",
@ -556,7 +556,7 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch):
[
{
"id": -1,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5",
"api_key": "k-yaml",
"billing_tier": "premium",
@ -566,7 +566,7 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch):
},
{
"id": -2,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "openai/gpt-5",
"api_key": "k-or",
"billing_tier": "premium",
@ -608,7 +608,7 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch
[
{
"id": -1,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5",
"api_key": "k-yaml",
"billing_tier": "premium",
@ -618,7 +618,7 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch
},
{
"id": -2,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "google/gemini-flash:free",
"api_key": "k-or",
"billing_tier": "free",
@ -656,7 +656,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch):
high_score_cfgs = [
{
"id": -i,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": f"gpt-x-{i}",
"api_key": "k",
"billing_tier": "premium",
@ -668,7 +668,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch):
]
low_score_trap = {
"id": -99,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "tiny-legacy",
"api_key": "k",
"billing_tier": "premium",
@ -729,7 +729,7 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch):
[
{
"id": -1,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "venice/dead-model",
"api_key": "k",
"billing_tier": "premium",
@ -739,7 +739,7 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch):
},
{
"id": -2,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5",
"api_key": "k",
"billing_tier": "premium",
@ -781,7 +781,7 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch):
[
{
"id": -1,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5",
"api_key": "k",
"billing_tier": "premium",
@ -791,7 +791,7 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch):
},
{
"id": -2,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5-pro",
"api_key": "k",
"billing_tier": "premium",
@ -839,7 +839,7 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
[
{
"id": -1,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "google/gemma-4-26b-a4b-it:free",
"api_key": "k",
"billing_tier": "free",
@ -849,7 +849,7 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
},
{
"id": -2,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "google/gemini-2.5-flash:free",
"api_key": "k",
"billing_tier": "free",
@ -892,7 +892,7 @@ async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch):
[
{
"id": -1,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "google/gemma-4-26b-a4b-it:free",
"api_key": "k",
"billing_tier": "free",
@ -937,7 +937,7 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa
[
{
"id": -1,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "google/gemma-4-26b-a4b-it:free",
"api_key": "k",
"billing_tier": "free",
@ -947,7 +947,7 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa
},
{
"id": -2,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "google/gemini-2.5-flash:free",
"api_key": "k",
"billing_tier": "free",

View file

@ -74,7 +74,7 @@ def _thread(*, pinned: int | None = None):
def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
return {
"id": id_,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": f"vision-{id_}",
"api_key": "k",
"billing_tier": tier,
@ -87,7 +87,7 @@ def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict:
return {
"id": id_,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": f"text-{id_}",
"api_key": "k",
"billing_tier": tier,
@ -261,7 +261,7 @@ async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch):
session = _FakeSession(_thread())
cfg_unannotated_vision = {
"id": -2,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-4o", # known vision model in LiteLLM map
"api_key": "k",
"billing_tier": "free",

View file

@ -25,10 +25,10 @@ def _fake_yaml_config(
return {
"id": id,
"name": f"yaml-{id}",
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": model_name,
"api_key": "sk-test",
"api_base": "",
"api_base": "https://api.openai.com/v1",
"billing_tier": billing_tier,
"rpm": 100,
"tpm": 100_000,
@ -54,10 +54,10 @@ def _fake_openrouter_config(
return {
"id": id,
"name": f"or-{id}",
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": model_name,
"api_key": "sk-or-test",
"api_base": "",
"api_base": "https://openrouter.ai/api/v1",
"billing_tier": billing_tier,
"rpm": 20 if billing_tier == "free" else 200,
"tpm": 100_000 if billing_tier == "free" else 1_000_000,

View file

@ -25,7 +25,7 @@ def _or_cfg(
) -> dict:
return {
"id": cid,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": model_name,
"billing_tier": tier,
"auto_pin_tier": "B" if tier == "premium" else "C",
@ -144,7 +144,7 @@ async def test_enrich_health_only_touches_or_provider(monkeypatch):
"""YAML cfgs that aren't OPENROUTER must be skipped entirely."""
yaml_cfg = {
"id": -1,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5",
"billing_tier": "premium",
"auto_pin_tier": "A",
@ -313,7 +313,7 @@ async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch):
"""When the catalogue has no OR cfgs at all, no HTTP calls fire."""
yaml_cfg: dict[str, Any] = {
"id": -1,
"provider": "AZURE_OPENAI",
"litellm_provider": "azure",
"model_name": "gpt-5",
"billing_tier": "premium",
}

View file

@ -35,7 +35,7 @@ def test_safety_net_does_not_fire_for_azure_gpt_4o():
it text-only."""
assert (
is_known_text_only_chat_model(
provider="AZURE_OPENAI",
litellm_provider="azure",
model_name="my-azure-deployment",
base_model="gpt-4o",
)
@ -49,7 +49,7 @@ def test_safety_net_does_not_fire_for_unknown_model():
LiteLLM doesn't know about must flow through to the provider."""
assert (
is_known_text_only_chat_model(
provider="CUSTOM",
litellm_provider="custom",
custom_provider="brand_new_proxy",
model_name="brand-new-model-x9",
)
@ -69,7 +69,7 @@ def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch):
assert (
is_known_text_only_chat_model(
provider="OPENAI",
litellm_provider="openai",
model_name="gpt-4o",
)
is False
@ -88,7 +88,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false)
assert (
is_known_text_only_chat_model(
provider="OPENAI",
litellm_provider="openai",
model_name="text-only-stub",
)
is True
@ -100,7 +100,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
monkeypatch.setattr(pc.litellm, "get_model_info", _info_true)
assert (
is_known_text_only_chat_model(
provider="OPENAI",
litellm_provider="openai",
model_name="vision-stub",
)
is False
@ -112,7 +112,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing)
assert (
is_known_text_only_chat_model(
provider="OPENAI",
litellm_provider="openai",
model_name="missing-key-stub",
)
is False