feat(chat): route models by provider capabilities

This commit is contained in:
Anish Sarkar 2026-06-11 18:22:23 +05:30
parent 8f20a32571
commit c28c4f5785
18 changed files with 429 additions and 319 deletions

View file

@ -2,9 +2,9 @@
LLM configuration utilities for SurfSense agents. LLM configuration utilities for SurfSense agents.
This module provides functions for loading LLM configurations from: This module provides functions for loading LLM configurations from:
1. Auto mode (ID 0) - Uses LiteLLM Router for load balancing 1. Auto mode (ID 0) - Resolved by callers to a concrete model-connection model
2. YAML files (global configs with negative IDs) 2. YAML files (global configs with negative IDs)
3. Database NewLLMConfig table (user-created configs with positive IDs) 3. Database model-connections table (user-created configs with positive IDs)
It also provides utilities for creating ChatLiteLLM instances and It also provides utilities for creating ChatLiteLLM instances and
managing prompt configurations. managing prompt configurations.
@ -33,9 +33,7 @@ from app.agents.chat.runtime.prompt_caching import (
from app.services.llm_router_service import ( from app.services.llm_router_service import (
AUTO_MODE_ID, AUTO_MODE_ID,
ChatLiteLLMRouter, ChatLiteLLMRouter,
LLMRouterService,
_sanitize_content, _sanitize_content,
get_auto_mode_llm,
is_auto_mode, is_auto_mode,
) )
@ -92,14 +90,6 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
yield chunk yield chunk
# Re-exported under the historical name ``PROVIDER_MAP``. Source of truth lives
# in provider_capabilities so the YAML loader can resolve prefixes during
# app.config init without importing the agent/tools tree.
from app.services.provider_capabilities import ( # noqa: E402
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
)
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None: def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
"""Attach a ``profile`` dict to ChatLiteLLM with model context metadata.""" """Attach a ``profile`` dict to ChatLiteLLM with model context metadata."""
try: try:
@ -122,7 +112,8 @@ class AgentConfig:
Complete configuration for the SurfSense agent. Complete configuration for the SurfSense agent.
This combines LLM settings with prompt configuration from NewLLMConfig. This combines LLM settings with prompt configuration from NewLLMConfig.
Supports Auto mode (ID 0) which uses LiteLLM Router for load balancing. Supports Auto mode metadata (ID 0). Runtime callers must resolve Auto to
a concrete global or BYOK model before constructing ChatLiteLLM.
""" """
# LLM Model Settings # LLM Model Settings
@ -219,7 +210,7 @@ class AgentConfig:
# BYOK rows have no curated flag; ask LiteLLM (default-allow on # BYOK rows have no curated flag; ask LiteLLM (default-allow on
# unknown). The streaming safety net still blocks explicit text-only. # unknown). The streaming safety net still blocks explicit text-only.
supports_image_input=derive_supports_image_input( supports_image_input=derive_supports_image_input(
provider=provider_value, litellm_provider=provider_value.lower(),
model_name=config.model_name, model_name=config.model_name,
base_model=base_model, base_model=base_model,
custom_provider=config.custom_provider, custom_provider=config.custom_provider,
@ -238,7 +229,7 @@ class AgentConfig:
system_instructions = yaml_config.get("system_instructions", "") system_instructions = yaml_config.get("system_instructions", "")
provider = yaml_config.get("provider", "").upper() provider = yaml_config.get("litellm_provider", "")
model_name = yaml_config.get("model_name", "") model_name = yaml_config.get("model_name", "")
custom_provider = yaml_config.get("custom_provider") custom_provider = yaml_config.get("custom_provider")
litellm_params = yaml_config.get("litellm_params") or {} litellm_params = yaml_config.get("litellm_params") or {}
@ -254,7 +245,7 @@ class AgentConfig:
supports_image_input = bool(yaml_config.get("supports_image_input")) supports_image_input = bool(yaml_config.get("supports_image_input"))
else: else:
supports_image_input = derive_supports_image_input( supports_image_input = derive_supports_image_input(
provider=provider, litellm_provider=provider,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
custom_provider=custom_provider, custom_provider=custom_provider,
@ -383,9 +374,6 @@ async def load_agent_config(
) -> "AgentConfig | None": ) -> "AgentConfig | None":
"""Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB.""" """Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB."""
if is_auto_mode(config_id): if is_auto_mode(config_id):
if not LLMRouterService.is_initialized():
print("Error: Auto mode requested but LLM Router not initialized")
return None
return AgentConfig.from_auto_mode() return AgentConfig.from_auto_mode()
if config_id < 0: if config_id < 0:
@ -408,9 +396,8 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
if llm_config.get("custom_provider"): if llm_config.get("custom_provider"):
model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}" model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}"
else: else:
provider = llm_config.get("provider", "").upper() litellm_provider = llm_config.get("litellm_provider", "openai")
provider_prefix = PROVIDER_MAP.get(provider, provider.lower()) model_string = f"{litellm_provider}/{llm_config['model_name']}"
model_string = f"{provider_prefix}/{llm_config['model_name']}"
litellm_kwargs = { litellm_kwargs = {
"model": model_string, "model": model_string,
@ -433,29 +420,15 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
def create_chat_litellm_from_agent_config( def create_chat_litellm_from_agent_config(
agent_config: AgentConfig, agent_config: AgentConfig,
) -> ChatLiteLLM | ChatLiteLLMRouter | None: ) -> ChatLiteLLM | ChatLiteLLMRouter | None:
"""Create a ChatLiteLLM (or, for Auto mode, a load-balancing router) from config.""" """Create a ChatLiteLLM from an already resolved concrete model config."""
if agent_config.is_auto_mode: if agent_config.is_auto_mode:
if not LLMRouterService.is_initialized(): print("Error: Auto mode must be resolved to a concrete model before LLM creation")
print("Error: Auto mode requested but LLM Router not initialized") return None
return None
try:
router_llm = get_auto_mode_llm()
if router_llm is not None:
# Universal injection points only: auto-mode fans out across
# providers, so provider-specific kwargs have no known target.
apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
return router_llm
except Exception as e:
print(f"Error creating ChatLiteLLMRouter: {e}")
return None
if agent_config.custom_provider: if agent_config.custom_provider:
model_string = f"{agent_config.custom_provider}/{agent_config.model_name}" model_string = f"{agent_config.custom_provider}/{agent_config.model_name}"
else: else:
provider_prefix = PROVIDER_MAP.get( model_string = f"{agent_config.provider}/{agent_config.model_name}"
agent_config.provider, agent_config.provider.lower()
)
model_string = f"{provider_prefix}/{agent_config.model_name}"
litellm_kwargs = { litellm_kwargs = {
"model": model_string, "model": model_string,

View file

@ -132,7 +132,7 @@ async def list_anonymous_models():
id=cfg.get("id", 0), id=cfg.get("id", 0),
name=cfg.get("name", ""), name=cfg.get("name", ""),
description=cfg.get("description"), description=cfg.get("description"),
provider=cfg.get("provider", ""), provider=cfg.get("litellm_provider", ""),
model_name=cfg.get("model_name", ""), model_name=cfg.get("model_name", ""),
billing_tier=cfg.get("billing_tier", "free"), billing_tier=cfg.get("billing_tier", "free"),
is_premium=cfg.get("billing_tier", "free") == "premium", is_premium=cfg.get("billing_tier", "free") == "premium",
@ -161,7 +161,7 @@ async def get_anonymous_model(slug: str):
id=cfg.get("id", 0), id=cfg.get("id", 0),
name=cfg.get("name", ""), name=cfg.get("name", ""),
description=cfg.get("description"), description=cfg.get("description"),
provider=cfg.get("provider", ""), provider=cfg.get("litellm_provider", ""),
model_name=cfg.get("model_name", ""), model_name=cfg.get("model_name", ""),
billing_tier=cfg.get("billing_tier", "free"), billing_tier=cfg.get("billing_tier", "free"),
is_premium=cfg.get("billing_tier", "free") == "premium", is_premium=cfg.get("billing_tier", "free") == "premium",

View file

@ -96,7 +96,7 @@ async def get_global_vision_llm_configs(
"id": cfg.get("id"), "id": cfg.get("id"),
"name": cfg.get("name"), "name": cfg.get("name"),
"description": cfg.get("description"), "description": cfg.get("description"),
"provider": cfg.get("provider"), "provider": cfg.get("litellm_provider"),
"custom_provider": cfg.get("custom_provider"), "custom_provider": cfg.get("custom_provider"),
"model_name": cfg.get("model_name"), "model_name": cfg.get("model_name"),
"api_base": cfg.get("api_base") or None, "api_base": cfg.get("api_base") or None,

View file

@ -23,9 +23,10 @@ from uuid import UUID
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.config import config from app.config import config
from app.db import NewChatThread from app.db import Connection, Model, NewChatThread
from app.services.quality_score import _QUALITY_TOP_K from app.services.quality_score import _QUALITY_TOP_K
from app.services.token_quota_service import TokenQuotaService from app.services.token_quota_service import TokenQuotaService
@ -61,11 +62,20 @@ def _is_usable_global_config(cfg: dict) -> bool:
return bool( return bool(
cfg.get("id") is not None cfg.get("id") is not None
and cfg.get("model_name") and cfg.get("model_name")
and cfg.get("provider") and cfg.get("litellm_provider")
and cfg.get("api_key") and cfg.get("api_key")
) )
def _has_capability(model: dict | Model, capability: str) -> bool:
caps = (
model.get("capabilities", {})
if isinstance(model, dict)
else model.capabilities or {}
)
return bool(caps.get(capability))
def _prune_runtime_cooldowns(now_ts: float | None = None) -> None: def _prune_runtime_cooldowns(now_ts: float | None = None) -> None:
now = time.time() if now_ts is None else now_ts now = time.time() if now_ts is None else now_ts
stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now] stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now]
@ -186,15 +196,19 @@ def _cfg_supports_image_input(cfg: dict) -> bool:
else None else None
) )
return derive_supports_image_input( return derive_supports_image_input(
provider=cfg.get("provider"), litellm_provider=cfg.get("litellm_provider"),
model_name=cfg.get("model_name"), model_name=cfg.get("model_name"),
base_model=base_model, base_model=base_model,
custom_provider=cfg.get("custom_provider"), custom_provider=cfg.get("custom_provider"),
) )
def _global_candidates(*, requires_image_input: bool = False) -> list[dict]: def _global_candidates(
"""Return Auto-eligible global cfgs. *,
capability: str = "chat",
requires_image_input: bool = False,
) -> list[dict]:
"""Return Auto-eligible global virtual models.
Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
@ -205,17 +219,135 @@ def _global_candidates(*, requires_image_input: bool = False) -> list[dict]:
filters out configs whose ``supports_image_input`` resolves to False filters out configs whose ``supports_image_input`` resolves to False
so a text-only deployment can't be pinned for an image request. so a text-only deployment can't be pinned for an image request.
""" """
candidates = [ connection_by_id = {
cfg int(conn.get("id")): conn
for conn in config.GLOBAL_CONNECTIONS
if conn.get("id") is not None
}
config_by_model_name = {
cfg.get("model_name"): cfg
for cfg in config.GLOBAL_LLM_CONFIGS for cfg in config.GLOBAL_LLM_CONFIGS
if _is_usable_global_config(cfg) if _is_usable_global_config(cfg)
and not cfg.get("health_gated") }
and not _is_runtime_cooled_down(int(cfg.get("id", 0))) candidates: list[dict] = []
and (not requires_image_input or _cfg_supports_image_input(cfg)) for model in config.GLOBAL_MODELS:
] model_id = int(model.get("id", 0))
if model_id >= 0 or _is_runtime_cooled_down(model_id):
continue
if not _has_capability(model, capability):
continue
cfg = config_by_model_name.get(model.get("model_id")) or {}
if cfg.get("health_gated"):
continue
if requires_image_input and not _has_capability(model, "vision"):
continue
if requires_image_input and cfg and not _cfg_supports_image_input(cfg):
continue
connection = connection_by_id.get(int(model.get("connection_id", 0)))
if not connection:
continue
catalog = model.get("catalog") or {}
candidates.append(
{
"id": model_id,
"model_id": model.get("model_id"),
"source": "global",
"connection": connection,
"capabilities": model.get("capabilities") or {},
"billing_tier": model.get("billing_tier", "free"),
"litellm_provider": connection.get("litellm_provider"),
"model_name": model.get("model_id"),
"auto_pin_tier": catalog.get("auto_pin_tier")
or cfg.get("auto_pin_tier")
or "A",
"quality_score": catalog.get("quality_score")
or cfg.get("quality_score")
or cfg.get("quality_score_static")
or 50,
}
)
return sorted(candidates, key=lambda c: int(c.get("id", 0))) return sorted(candidates, key=lambda c: int(c.get("id", 0)))
async def _db_candidates(
session: AsyncSession,
*,
search_space_id: int,
user_id: str | UUID | None,
capability: str,
requires_image_input: bool = False,
) -> list[dict]:
parsed_user_id = _to_uuid(user_id)
stmt = (
select(Model)
.options(selectinload(Model.connection))
.join(Connection, Model.connection_id == Connection.id)
.where(Model.enabled.is_(True), Connection.enabled.is_(True))
)
result = await session.execute(stmt)
candidates: list[dict] = []
for model in result.scalars().all():
conn = model.connection
if not conn:
continue
if conn.search_space_id is not None and conn.search_space_id != search_space_id:
continue
if conn.user_id is not None and parsed_user_id is not None and conn.user_id != parsed_user_id:
continue
if conn.user_id is not None and parsed_user_id is None:
continue
if not _has_capability(model, capability):
continue
if requires_image_input and not _has_capability(model, "vision"):
continue
model_id = int(model.id)
if _is_runtime_cooled_down(model_id):
continue
catalog = model.catalog or {}
candidates.append(
{
"id": model_id,
"model_id": model.model_id,
"source": "db",
"connection": conn,
"capabilities": model.capabilities or {},
"billing_tier": "byok",
"litellm_provider": conn.litellm_provider,
"model_name": model.model_id,
"auto_pin_tier": catalog.get("auto_pin_tier") or "BYOK",
"quality_score": catalog.get("quality_score") or 75,
}
)
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
async def auto_model_candidates(
session: AsyncSession,
*,
search_space_id: int,
user_id: str | UUID | None,
capability: str,
requires_image_input: bool = False,
exclude_model_ids: set[int] | None = None,
) -> list[dict]:
excluded_ids = {int(mid) for mid in (exclude_model_ids or set())}
db_candidates = await _db_candidates(
session,
search_space_id=search_space_id,
user_id=user_id,
capability=capability,
requires_image_input=requires_image_input,
)
candidates = [
*_global_candidates(
capability=capability,
requires_image_input=requires_image_input,
),
*db_candidates,
]
return [c for c in candidates if int(c.get("id", 0)) not in excluded_ids]
def _tier_of(cfg: dict) -> str: def _tier_of(cfg: dict) -> str:
return str(cfg.get("billing_tier", "free")).lower() return str(cfg.get("billing_tier", "free")).lower()
@ -223,8 +355,9 @@ def _tier_of(cfg: dict) -> str:
def _is_preferred_premium_auto_config(cfg: dict) -> bool: def _is_preferred_premium_auto_config(cfg: dict) -> bool:
"""Return True for the operator-preferred premium Auto model.""" """Return True for the operator-preferred premium Auto model."""
return ( return (
_tier_of(cfg) == "premium" cfg.get("source") == "global"
and str(cfg.get("provider", "")).upper() == "AZURE_OPENAI" and _tier_of(cfg) == "premium"
and str(cfg.get("litellm_provider", "")).lower() == "azure"
and str(cfg.get("model_name", "")).lower() == "gpt-5.4" and str(cfg.get("model_name", "")).lower() == "gpt-5.4"
) )
@ -251,6 +384,11 @@ def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
return top_k[idx], len(top_k) return top_k[idx], len(top_k)
def choose_auto_model_candidate(candidates: list[dict], seed_id: int) -> dict:
selected, _ = _select_pin(candidates, seed_id)
return selected
def _to_uuid(user_id: str | UUID | None) -> UUID | None: def _to_uuid(user_id: str | UUID | None) -> UUID | None:
if user_id is None: if user_id is None:
return None return None
@ -326,20 +464,23 @@ async def resolve_or_get_pinned_llm_config_id(
) )
excluded_ids = {int(cid) for cid in (exclude_config_ids or set())} excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
candidates = [ candidates = await auto_model_candidates(
c session,
for c in _global_candidates(requires_image_input=requires_image_input) search_space_id=search_space_id,
if int(c.get("id", 0)) not in excluded_ids user_id=user_id,
] capability="chat",
requires_image_input=requires_image_input,
exclude_model_ids=excluded_ids,
)
if not candidates: if not candidates:
if requires_image_input: if requires_image_input:
# Distinguish the "no vision-capable cfg" case from generic # Distinguish the "no vision-capable cfg" case from generic
# "no usable cfg" so the streaming task can map this to the # "no usable cfg" so the streaming task can map this to the
# MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error. # MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error.
raise ValueError( raise ValueError(
"No vision-capable global LLM configs are available for Auto mode" "No vision-capable LLM models are available for Auto mode"
) )
raise ValueError("No usable global LLM configs are available for Auto mode") raise ValueError("No usable LLM models are available for Auto mode")
candidate_by_id = {int(c["id"]): c for c in candidates} candidate_by_id = {int(c["id"]): c for c in candidates}
# Reuse an existing valid pin without re-checking current quota (no silent # Reuse an existing valid pin without re-checking current quota (no silent
@ -379,24 +520,13 @@ async def resolve_or_get_pinned_llm_config_id(
# log that explicitly so operators can correlate the re-pin with # log that explicitly so operators can correlate the re-pin with
# the user's image attachment instead of suspecting a cooldown. # the user's image attachment instead of suspecting a cooldown.
if requires_image_input: if requires_image_input:
try: logger.info(
pinned_global = next( "auto_pin_repinned_for_image thread_id=%s search_space_id=%s "
c "previous_config_id=%s",
for c in config.GLOBAL_LLM_CONFIGS thread_id,
if int(c.get("id", 0)) == int(pinned_id) search_space_id,
) pinned_id,
except StopIteration: )
pinned_global = None
if pinned_global is not None and not _cfg_supports_image_input(
pinned_global
):
logger.info(
"auto_pin_repinned_for_image thread_id=%s search_space_id=%s "
"previous_config_id=%s",
thread_id,
search_space_id,
pinned_id,
)
logger.info( logger.info(
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s", "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
thread_id, thread_id,

View file

@ -30,11 +30,7 @@ from litellm.exceptions import (
) )
from pydantic import Field from pydantic import Field
from app.services.model_resolver import ( from app.services.model_resolver import native_connection_from_config, to_litellm
NATIVE_PROVIDER_PREFIX,
native_connection_from_config,
to_litellm,
)
from app.utils.perf import get_perf_logger from app.utils.perf import get_perf_logger
litellm.json_logs = False litellm.json_logs = False
@ -101,10 +97,6 @@ def _sanitize_content(content: Any) -> Any:
# Special ID for Auto mode - uses router for load balancing # Special ID for Auto mode - uses router for load balancing
AUTO_MODE_ID = 0 AUTO_MODE_ID = 0
# Historical export kept for callers that still import PROVIDER_MAP.
PROVIDER_MAP = NATIVE_PROVIDER_PREFIX
class LLMRouterService: class LLMRouterService:
""" """
Singleton service for managing LiteLLM Router. Singleton service for managing LiteLLM Router.

View file

@ -10,13 +10,11 @@ from sqlalchemy.orm import selectinload
from app.config import config from app.config import config
from app.db import Model, SearchSpace from app.db import Model, SearchSpace
from app.services.llm_router_service import ( from app.services.auto_model_pin_service import (
AUTO_MODE_ID, auto_model_candidates,
ChatLiteLLMRouter, choose_auto_model_candidate,
LLMRouterService,
get_auto_mode_llm,
is_auto_mode,
) )
from app.services.llm_router_service import AUTO_MODE_ID, ChatLiteLLMRouter, is_auto_mode
from app.services.model_resolver import native_connection_from_config, to_litellm from app.services.model_resolver import native_connection_from_config, to_litellm
from app.services.token_tracking_service import token_tracker from app.services.token_tracking_service import token_tracker
@ -78,7 +76,7 @@ def _legacy_config_connection(
api_version: str | None = None, api_version: str | None = None,
) -> tuple[str, dict]: ) -> tuple[str, dict]:
cfg = { cfg = {
"provider": provider, "litellm_provider": provider.lower(),
"model_name": model_name, "model_name": model_name,
"api_key": api_key, "api_key": api_key,
"api_base": api_base, "api_base": api_base,
@ -325,23 +323,21 @@ async def get_search_space_llm_instance(
logger.error(f"No {role} LLM configured for search space {search_space_id}") logger.error(f"No {role} LLM configured for search space {search_space_id}")
return None return None
# Check for Auto mode (ID 0) - use router for load balancing # Auto mode resolves to one concrete global or BYOK model from the
# unified model-connections catalog.
if is_auto_mode(llm_config_id): if is_auto_mode(llm_config_id):
if not LLMRouterService.is_initialized(): candidates = await auto_model_candidates(
logger.error( session,
"Auto mode requested but LLM Router not initialized. " search_space_id=search_space_id,
"Ensure global_llm_config.yaml exists with valid configs." user_id=search_space.user_id,
) capability="chat",
return None )
if not candidates:
try: logger.error("No chat-capable models available for Auto mode")
logger.debug(
f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}"
)
return get_auto_mode_llm(streaming=not disable_streaming)
except Exception as e:
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
return None return None
llm_config_id = int(
choose_auto_model_candidate(candidates, search_space_id)["id"]
)
# Check if this is a global virtual model (negative ID) # Check if this is a global virtual model (negative ID)
if llm_config_id < 0: if llm_config_id < 0:
@ -414,7 +410,7 @@ async def get_vision_llm(
"""Get the search space's vision LLM instance for screenshot analysis. """Get the search space's vision LLM instance for screenshot analysis.
Resolves from the new connection/model role bindings: Resolves from the new connection/model role bindings:
- Auto mode (ID 0): VisionLLMRouterService - Auto mode (ID 0): unified global/BYOK model candidate selection
- Global (negative ID): virtual GLOBAL models from YAML - Global (negative ID): virtual GLOBAL models from YAML
- DB (positive ID): Model + Connection tables - DB (positive ID): Model + Connection tables
@ -424,10 +420,7 @@ async def get_vision_llm(
unwrapped they don't consume premium credit (issue M). unwrapped they don't consume premium credit (issue M).
""" """
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
from app.services.vision_llm_router_service import ( from app.services.vision_llm_router_service import is_vision_auto_mode
VisionLLMRouterService,
is_vision_auto_mode,
)
try: try:
result = await session.execute( result = await session.execute(
@ -476,26 +469,16 @@ async def get_vision_llm(
return None return None
if is_vision_auto_mode(config_id): if is_vision_auto_mode(config_id):
if not VisionLLMRouterService.is_initialized(): candidates = await auto_model_candidates(
logger.error( session,
"Vision Auto mode requested but Vision LLM Router not initialized" search_space_id=search_space_id,
) user_id=owner_user_id,
return None capability="vision",
try: )
# Auto mode is currently treated as free at the wrapper if not candidates:
# level — the underlying router can dispatch to either logger.error("No vision-capable models available for Auto mode")
# premium or free YAML configs but routing decisions are
# opaque. If/when we want to bill Auto-routed vision
# calls we'd need to thread the resolved deployment's
# billing_tier back from the router. For now we keep
# parity with chat Auto, which also doesn't pre-classify.
return ChatLiteLLMRouter(
router=VisionLLMRouterService.get_router(),
streaming=True,
)
except Exception as e:
logger.error(f"Failed to create vision ChatLiteLLMRouter: {e}")
return None return None
config_id = int(choose_auto_model_candidate(candidates, search_space_id)["id"])
if config_id < 0: if config_id < 0:
global_model = get_global_model(config_id) global_model = get_global_model(config_id)

View file

@ -154,19 +154,19 @@ def _process_models(raw_models: list[dict]) -> list[dict]:
} }
) )
# 2) Emit for the native provider when we have a mapping # 2) Emit for the direct provider when we have a mapping
native_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug) direct_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug)
if native_provider: if direct_provider:
# Google's Gemini API only serves gemini-* models. # Google's Gemini API only serves gemini-* models.
# Open-source models like gemma-* are NOT available through it. # Open-source models like gemma-* are NOT available through it.
if native_provider == "GOOGLE" and not model_name.startswith("gemini-"): if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"):
continue continue
processed.append( processed.append(
{ {
"value": model_name, "value": model_name,
"label": name, "label": name,
"provider": native_provider, "provider": direct_provider,
"context_window": context_window, "context_window": context_window,
} }
) )

View file

@ -9,53 +9,12 @@ from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from app.services.provider_api_base import resolve_api_base
if TYPE_CHECKING: if TYPE_CHECKING:
from app.db import Connection from app.db import Connection
PROTOCOL_OLLAMA = "OLLAMA" PROTOCOL_OLLAMA = "OLLAMA"
PROTOCOL_OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" PROTOCOL_OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE"
PROTOCOL_NATIVE = "NATIVE" PROTOCOL_ANTHROPIC = "ANTHROPIC"
NATIVE_PROVIDER_PREFIX: dict[str, str] = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"AZURE": "azure",
"OPENROUTER": "openrouter",
"COMETAPI": "cometapi",
"XAI": "xai",
"BEDROCK": "bedrock",
"AWS_BEDROCK": "bedrock",
"VERTEX_AI": "vertex_ai",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
"GITHUB_MODELS": "github",
"REPLICATE": "replicate",
"PERPLEXITY": "perplexity",
"ANYSCALE": "anyscale",
"DEEPINFRA": "deepinfra",
"CEREBRAS": "cerebras",
"SAMBANOVA": "sambanova",
"AI21": "ai21",
"CLOUDFLARE": "cloudflare",
"DATABRICKS": "databricks",
"HUGGINGFACE": "huggingface",
"MINIMAX": "openai",
"RECRAFT": "recraft",
"XINFERENCE": "xinference",
"NSCALE": "nscale",
"CUSTOM": "custom",
}
def ensure_v1(base_url: str | None) -> str | None: def ensure_v1(base_url: str | None) -> str | None:
@ -77,6 +36,23 @@ def _protocol_value(protocol: Any) -> str:
return getattr(protocol, "value", str(protocol)) return getattr(protocol, "value", str(protocol))
def default_litellm_provider(protocol: Any) -> str:
protocol_value = _protocol_value(protocol)
defaults = {
PROTOCOL_OLLAMA: "ollama_chat",
PROTOCOL_ANTHROPIC: "anthropic",
PROTOCOL_OPENAI_COMPATIBLE: "openai",
}
return defaults.get(protocol_value, "openai")
def _execution_api_base(protocol: str, base_url: str | None) -> str | None:
del protocol
if not base_url:
return None
return base_url.rstrip("/")
def to_litellm( def to_litellm(
conn: Connection | Mapping[str, Any], conn: Connection | Mapping[str, Any],
model_id: str, model_id: str,
@ -85,38 +61,19 @@ def to_litellm(
protocol = _protocol_value(_conn_value(conn, "protocol")) protocol = _protocol_value(_conn_value(conn, "protocol"))
base_url = _conn_value(conn, "base_url") base_url = _conn_value(conn, "base_url")
api_key = _conn_value(conn, "api_key") api_key = _conn_value(conn, "api_key")
native_provider = _conn_value(conn, "native_provider") litellm_provider = (
_conn_value(conn, "litellm_provider") or default_litellm_provider(protocol)
)
extra = _conn_value(conn, "extra") or {} extra = _conn_value(conn, "extra") or {}
kwargs: dict[str, Any] = {} kwargs: dict[str, Any] = {}
if api_key: if api_key:
kwargs["api_key"] = api_key kwargs["api_key"] = api_key
if protocol == PROTOCOL_OLLAMA: model_string = f"{litellm_provider}/{model_id}" if litellm_provider else model_id
model_string = f"ollama_chat/{model_id}" api_base = _execution_api_base(protocol, base_url)
if base_url: if api_base:
kwargs["api_base"] = base_url.rstrip("/") kwargs["api_base"] = api_base
elif protocol == PROTOCOL_OPENAI_COMPATIBLE:
model_string = f"openai/{model_id}"
api_base = ensure_v1(base_url)
if api_base:
kwargs["api_base"] = api_base
else:
provider_key = (native_provider or "").upper()
prefix = NATIVE_PROVIDER_PREFIX.get(provider_key, provider_key.lower())
if prefix == "custom":
custom_provider = extra.get("custom_provider") or native_provider
model_string = f"{custom_provider}/{model_id}" if custom_provider else model_id
else:
model_string = f"{prefix}/{model_id}"
api_base = resolve_api_base(
provider=provider_key,
provider_prefix=prefix,
config_api_base=base_url,
)
if api_base:
kwargs["api_base"] = api_base
if api_version := extra.get("api_version"): if api_version := extra.get("api_version"):
kwargs["api_version"] = api_version kwargs["api_version"] = api_version
@ -126,18 +83,21 @@ def to_litellm(
def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]:
"""Build an in-memory NATIVE connection mapping from a legacy/global config.""" """Build an in-memory connection mapping from a global config."""
provider = str(config.get("provider") or config.get("custom_provider") or "CUSTOM") protocol = str(config.get("protocol") or PROTOCOL_OPENAI_COMPATIBLE)
litellm_provider = str(
config.get("litellm_provider")
or config.get("custom_provider")
or default_litellm_provider(protocol)
)
extra: dict[str, Any] = { extra: dict[str, Any] = {
"litellm_params": config.get("litellm_params") or {}, "litellm_params": config.get("litellm_params") or {},
} }
if config.get("api_version"): if config.get("api_version"):
extra["api_version"] = config.get("api_version") extra["api_version"] = config.get("api_version")
if config.get("custom_provider"):
extra["custom_provider"] = config.get("custom_provider")
return { return {
"protocol": PROTOCOL_NATIVE, "protocol": protocol,
"native_provider": provider, "litellm_provider": litellm_provider,
"base_url": config.get("api_base") or None, "base_url": config.get("api_base") or None,
"api_key": config.get("api_key") or None, "api_key": config.get("api_key") or None,
"extra": extra, "extra": extra,
@ -145,7 +105,7 @@ def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]:
__all__ = [ __all__ = [
"NATIVE_PROVIDER_PREFIX", "default_litellm_provider",
"ensure_v1", "ensure_v1",
"native_connection_from_config", "native_connection_from_config",
"to_litellm", "to_litellm",

View file

@ -3,19 +3,12 @@ from typing import Any
from litellm import Router from litellm import Router
from app.services.model_resolver import ( from app.services.model_resolver import native_connection_from_config, to_litellm
NATIVE_PROVIDER_PREFIX,
native_connection_from_config,
to_litellm,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VISION_AUTO_MODE_ID = 0 VISION_AUTO_MODE_ID = 0
VISION_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX
class VisionLLMRouterService: class VisionLLMRouterService:
_instance = None _instance = None
_router: Router | None = None _router: Router | None = None
@ -141,12 +134,11 @@ def is_vision_auto_mode(config_id: int | None) -> bool:
def build_vision_model_string( def build_vision_model_string(
provider: str, model_name: str, custom_provider: str | None litellm_provider: str, model_name: str, custom_provider: str | None
) -> str: ) -> str:
if custom_provider: if custom_provider:
return f"{custom_provider}/{model_name}" return f"{custom_provider}/{model_name}"
prefix = VISION_PROVIDER_MAP.get(provider.upper(), provider.lower()) return f"{litellm_provider}/{model_name}"
return f"{prefix}/{model_name}"
def get_global_vision_llm_config(config_id: int) -> dict | None: def get_global_vision_llm_config(config_id: int) -> dict | None:

View file

@ -97,16 +97,16 @@ def _process_vision_models(raw_models: list[dict]) -> list[dict]:
} }
) )
native_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug) direct_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug)
if native_provider: if direct_provider:
if native_provider == "GOOGLE" and not model_name.startswith("gemini-"): if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"):
continue continue
processed.append( processed.append(
{ {
"value": model_name, "value": model_name,
"label": name, "label": name,
"provider": native_provider, "provider": direct_provider,
"context_window": context_window, "context_window": context_window,
} }
) )

View file

@ -40,7 +40,7 @@ def check_image_input_capability(
else None else None
) )
if not is_known_text_only_chat_model( if not is_known_text_only_chat_model(
provider=agent_config.provider, litellm_provider=agent_config.provider,
model_name=agent_config.model_name, model_name=agent_config.model_name,
base_model=agent_base_model, base_model=agent_base_model,
custom_provider=agent_config.custom_provider, custom_provider=agent_config.custom_provider,

View file

@ -80,7 +80,6 @@ async def _generate_title(
from litellm import acompletion from litellm import acompletion
from app.services.llm_router_service import LLMRouterService from app.services.llm_router_service import LLMRouterService
from app.services.provider_api_base import resolve_api_base
from app.services.token_tracking_service import _turn_accumulator from app.services.token_tracking_service import _turn_accumulator
# Excludes this turn's own assistant row (pre-written by # Excludes this turn's own assistant row (pre-written by
@ -125,26 +124,12 @@ async def _generate_title(
router = LLMRouterService.get_router() router = LLMRouterService.get_router()
response = await router.acompletion(model="auto", messages=messages) response = await router.acompletion(model="auto", messages=messages)
else: else:
# Apply the same ``api_base`` cascade chat / vision / image-gen
# call sites use so we never inherit ``litellm.api_base``
# (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat
# config itself ships an empty ``api_base``. Without this the
# title-gen on an OpenRouter chat config would 404 against the
# inherited Azure endpoint — see ``provider_api_base`` for the
# same bug repro on the image-gen / vision paths.
raw_model = getattr(llm, "model", "") or "" raw_model = getattr(llm, "model", "") or ""
provider_prefix = raw_model.split("/", 1)[0] if "/" in raw_model else None
provider_value = agent_config.provider if agent_config is not None else None
title_api_base = resolve_api_base(
provider=provider_value,
provider_prefix=provider_prefix,
config_api_base=getattr(llm, "api_base", None),
)
response = await acompletion( response = await acompletion(
model=raw_model, model=raw_model,
messages=messages, messages=messages,
api_key=getattr(llm, "api_key", None), api_key=getattr(llm, "api_key", None),
api_base=title_api_base, api_base=getattr(llm, "api_base", None),
) )
usage_info = None usage_info = None

View file

@ -1,8 +1,8 @@
"""Load an LLM + AgentConfig bundle for a given config id. """Load an LLM + AgentConfig bundle for a given config id.
Handles both code paths uniformly: Handles both code paths uniformly:
- ``config_id >= 0`` database-backed ``NewLLMConfig`` row (per-user/per-space). - ``config_id > 0`` database-backed model-connection ``Model`` row.
- ``config_id < 0`` YAML-defined global LLM config (built-in defaults). - ``config_id < 0`` virtual global model materialized from YAML/OpenRouter.
Returns ``(llm, agent_config, error_message)``; on success ``error_message`` is Returns ``(llm, agent_config, error_message)``; on success ``error_message`` is
``None``. The caller emits the friendly SSE error frame. ``None``. The caller emits the friendly SSE error frame.
@ -12,15 +12,72 @@ from __future__ import annotations
from typing import Any from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.agents.chat.runtime.llm_config import ( from app.agents.chat.runtime.llm_config import (
AgentConfig, AgentConfig,
create_chat_litellm_from_agent_config, SanitizedChatLiteLLM,
create_chat_litellm_from_config,
load_agent_config,
load_global_llm_config_by_id,
) )
from app.config import config
from app.db import Model, SearchSpace
from app.services.model_resolver import to_litellm
def _agent_config_from_resolved(
*,
config_id: int,
config_name: str | None,
provider: str,
model_name: str,
api_key: str | None,
api_base: str | None,
litellm_params: dict | None,
supports_image_input: bool,
billing_tier: str = "free",
) -> AgentConfig:
return AgentConfig(
provider=provider,
model_name=model_name,
api_key=api_key or "",
api_base=api_base,
custom_provider=None,
litellm_params=litellm_params,
config_id=config_id,
config_name=config_name,
is_auto_mode=False,
billing_tier=billing_tier,
is_premium=billing_tier == "premium",
supports_image_input=supports_image_input,
)
async def _load_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace | None:
result = await session.execute(select(SearchSpace).where(SearchSpace.id == search_space_id))
return result.scalars().first()
async def _load_db_model(
session: AsyncSession,
*,
model_id: int,
search_space: SearchSpace,
) -> Model | None:
result = await session.execute(
select(Model)
.options(selectinload(Model.connection))
.where(Model.id == model_id, Model.enabled.is_(True))
)
model = result.scalars().first()
if not model or not model.connection or not model.connection.enabled:
return None
conn = model.connection
if conn.search_space_id is not None and conn.search_space_id != search_space.id:
return None
if conn.user_id is not None and conn.user_id != search_space.user_id:
return None
return model
async def load_llm_bundle( async def load_llm_bundle(
@ -29,29 +86,67 @@ async def load_llm_bundle(
config_id: int, config_id: int,
search_space_id: int, search_space_id: int,
) -> tuple[Any, AgentConfig | None, str | None]: ) -> tuple[Any, AgentConfig | None, str | None]:
if config_id >= 0: search_space = await _load_search_space(session, search_space_id)
loaded_agent_config = await load_agent_config( if not search_space:
session=session, return None, None, f"Search space {search_space_id} not found"
config_id=config_id,
search_space_id=search_space_id, if config_id > 0:
model = await _load_db_model(
session,
model_id=config_id,
search_space=search_space,
) )
if not loaded_agent_config: if not model or not (model.capabilities or {}).get("chat"):
return ( return (
None, None,
None, None,
f"Failed to load NewLLMConfig with id {config_id}", f"Failed to load chat model with id {config_id}",
) )
model_string, litellm_kwargs = to_litellm(model.connection, model.model_id)
agent_config = _agent_config_from_resolved(
config_id=config_id,
config_name=model.display_name or model.model_id,
provider=model.connection.litellm_provider or "",
model_name=model.model_id,
api_key=model.connection.api_key,
api_base=model.connection.base_url,
litellm_params=(model.connection.extra or {}).get("litellm_params"),
supports_image_input=bool((model.capabilities or {}).get("vision")),
billing_tier="free",
)
return ( return (
create_chat_litellm_from_agent_config(loaded_agent_config), SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
loaded_agent_config, agent_config,
None, None,
) )
loaded_llm_config = load_global_llm_config_by_id(config_id) global_model = next((m for m in config.GLOBAL_MODELS if m.get("id") == config_id), None)
if not loaded_llm_config: if not global_model or not (global_model.get("capabilities") or {}).get("chat"):
return None, None, f"Failed to load LLM config with id {config_id}" return None, None, f"Failed to load global chat model with id {config_id}"
return ( global_connection = next(
create_chat_litellm_from_config(loaded_llm_config), (
AgentConfig.from_yaml_config(loaded_llm_config), c
for c in config.GLOBAL_CONNECTIONS
if c.get("id") == global_model.get("connection_id")
),
None,
)
if not global_connection:
return None, None, f"Failed to load global connection for model {config_id}"
model_string, litellm_kwargs = to_litellm(global_connection, global_model["model_id"])
agent_config = _agent_config_from_resolved(
config_id=config_id,
config_name=global_model.get("display_name") or global_model.get("model_id"),
provider=global_connection.get("litellm_provider") or "",
model_name=global_model["model_id"],
api_key=global_connection.get("api_key"),
api_base=global_connection.get("base_url"),
litellm_params=(global_connection.get("extra") or {}).get("litellm_params"),
supports_image_input=bool((global_model.get("capabilities") or {}).get("vision")),
billing_tier=str(global_model.get("billing_tier", "free")).lower(),
)
return (
SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
agent_config,
None, None,
) )

View file

@ -75,10 +75,10 @@ async def test_auto_first_turn_pins_one_model(monkeypatch):
config, config,
"GLOBAL_LLM_CONFIGS", "GLOBAL_LLM_CONFIGS",
[ [
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"},
{ {
"id": -1, "id": -1,
"provider": "OPENAI", "litellm_provider": "openai",
"model_name": "gpt-prem", "model_name": "gpt-prem",
"api_key": "k2", "api_key": "k2",
"billing_tier": "premium", "billing_tier": "premium",
@ -117,7 +117,7 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
[ [
{ {
"id": -2, "id": -2,
"provider": "OPENAI", "litellm_provider": "openai",
"model_name": "gpt-free", "model_name": "gpt-free",
"api_key": "k1", "api_key": "k1",
"billing_tier": "free", "billing_tier": "free",
@ -125,7 +125,7 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
}, },
{ {
"id": -1, "id": -1,
"provider": "OPENAI", "litellm_provider": "openai",
"model_name": "gpt-prem", "model_name": "gpt-prem",
"api_key": "k2", "api_key": "k2",
"billing_tier": "premium", "billing_tier": "premium",
@ -164,7 +164,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
[ [
{ {
"id": -1, "id": -1,
"provider": "AZURE_OPENAI", "litellm_provider": "azure",
"model_name": "gpt-5.1", "model_name": "gpt-5.1",
"api_key": "k1", "api_key": "k1",
"billing_tier": "premium", "billing_tier": "premium",
@ -173,7 +173,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
}, },
{ {
"id": -2, "id": -2,
"provider": "AZURE_OPENAI", "litellm_provider": "azure",
"model_name": "gpt-5.4", "model_name": "gpt-5.4",
"api_key": "k2", "api_key": "k2",
"billing_tier": "premium", "billing_tier": "premium",
@ -182,7 +182,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
}, },
{ {
"id": -3, "id": -3,
"provider": "OPENROUTER", "litellm_provider": "openrouter",
"model_name": "openai/gpt-5.4", "model_name": "openai/gpt-5.4",
"api_key": "k3", "api_key": "k3",
"billing_tier": "premium", "billing_tier": "premium",
@ -222,7 +222,7 @@ async def test_next_turn_reuses_existing_pin(monkeypatch):
[ [
{ {
"id": -1, "id": -1,
"provider": "OPENAI", "litellm_provider": "openai",
"model_name": "gpt-prem", "model_name": "gpt-prem",
"api_key": "k2", "api_key": "k2",
"billing_tier": "premium", "billing_tier": "premium",
@ -263,7 +263,7 @@ async def test_premium_eligible_auto_can_pin_premium(monkeypatch):
[ [
{ {
"id": -1, "id": -1,
"provider": "OPENAI", "litellm_provider": "openai",
"model_name": "gpt-prem", "model_name": "gpt-prem",
"api_key": "k2", "api_key": "k2",
"billing_tier": "premium", "billing_tier": "premium",
@ -301,14 +301,14 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch):
[ [
{ {
"id": -2, "id": -2,
"provider": "OPENAI", "litellm_provider": "openai",
"model_name": "gpt-free", "model_name": "gpt-free",
"api_key": "k1", "api_key": "k1",
"billing_tier": "free", "billing_tier": "free",
}, },
{ {
"id": -1, "id": -1,
"provider": "OPENAI", "litellm_provider": "openai",
"model_name": "gpt-prem", "model_name": "gpt-prem",
"api_key": "k2", "api_key": "k2",
"billing_tier": "premium", "billing_tier": "premium",
@ -346,14 +346,14 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch):
[ [
{ {
"id": -2, "id": -2,
"provider": "OPENAI", "litellm_provider": "openai",
"model_name": "gpt-free", "model_name": "gpt-free",
"api_key": "k1", "api_key": "k1",
"billing_tier": "free", "billing_tier": "free",
}, },
{ {
"id": -1, "id": -1,
"provider": "OPENAI", "litellm_provider": "openai",
"model_name": "gpt-prem", "model_name": "gpt-prem",
"api_key": "k2", "api_key": "k2",
"billing_tier": "premium", "billing_tier": "premium",
@ -391,14 +391,14 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch):
[ [
{ {
"id": -2, "id": -2,
"provider": "OPENAI", "litellm_provider": "openai",
"model_name": "gpt-free", "model_name": "gpt-free",
"api_key": "k1", "api_key": "k1",
"billing_tier": "free", "billing_tier": "free",
}, },
{ {
"id": -1, "id": -1,
"provider": "OPENAI", "litellm_provider": "openai",
"model_name": "gpt-prem", "model_name": "gpt-prem",
"api_key": "k2", "api_key": "k2",
"billing_tier": "premium", "billing_tier": "premium",
@ -437,7 +437,7 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch):
config, config,
"GLOBAL_LLM_CONFIGS", "GLOBAL_LLM_CONFIGS",
[ [
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"},
], ],
) )
@ -462,7 +462,7 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
config, config,
"GLOBAL_LLM_CONFIGS", "GLOBAL_LLM_CONFIGS",
[ [
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"},
], ],
) )
@ -504,7 +504,7 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch):
[ [
{ {
"id": -1, "id": -1,
"provider": "OPENROUTER", "litellm_provider": "openrouter",
"model_name": "venice/dead-model", "model_name": "venice/dead-model",
"api_key": "k1", "api_key": "k1",
"billing_tier": "free", "billing_tier": "free",
@ -514,7 +514,7 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch):
}, },
{ {
"id": -2, "id": -2,
"provider": "OPENROUTER", "litellm_provider": "openrouter",
"model_name": "google/gemini-flash", "model_name": "google/gemini-flash",
"api_key": "k1", "api_key": "k1",
"billing_tier": "free", "billing_tier": "free",
@ -556,7 +556,7 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch):
[ [
{ {
"id": -1, "id": -1,
"provider": "AZURE_OPENAI", "litellm_provider": "azure",
"model_name": "gpt-5", "model_name": "gpt-5",
"api_key": "k-yaml", "api_key": "k-yaml",
"billing_tier": "premium", "billing_tier": "premium",
@ -566,7 +566,7 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch):
}, },
{ {
"id": -2, "id": -2,
"provider": "OPENROUTER", "litellm_provider": "openrouter",
"model_name": "openai/gpt-5", "model_name": "openai/gpt-5",
"api_key": "k-or", "api_key": "k-or",
"billing_tier": "premium", "billing_tier": "premium",
@ -608,7 +608,7 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch
[ [
{ {
"id": -1, "id": -1,
"provider": "AZURE_OPENAI", "litellm_provider": "azure",
"model_name": "gpt-5", "model_name": "gpt-5",
"api_key": "k-yaml", "api_key": "k-yaml",
"billing_tier": "premium", "billing_tier": "premium",
@ -618,7 +618,7 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch
}, },
{ {
"id": -2, "id": -2,
"provider": "OPENROUTER", "litellm_provider": "openrouter",
"model_name": "google/gemini-flash:free", "model_name": "google/gemini-flash:free",
"api_key": "k-or", "api_key": "k-or",
"billing_tier": "free", "billing_tier": "free",
@ -656,7 +656,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch):
high_score_cfgs = [ high_score_cfgs = [
{ {
"id": -i, "id": -i,
"provider": "AZURE_OPENAI", "litellm_provider": "azure",
"model_name": f"gpt-x-{i}", "model_name": f"gpt-x-{i}",
"api_key": "k", "api_key": "k",
"billing_tier": "premium", "billing_tier": "premium",
@ -668,7 +668,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch):
] ]
low_score_trap = { low_score_trap = {
"id": -99, "id": -99,
"provider": "AZURE_OPENAI", "litellm_provider": "azure",
"model_name": "tiny-legacy", "model_name": "tiny-legacy",
"api_key": "k", "api_key": "k",
"billing_tier": "premium", "billing_tier": "premium",
@ -729,7 +729,7 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch):
[ [
{ {
"id": -1, "id": -1,
"provider": "OPENROUTER", "litellm_provider": "openrouter",
"model_name": "venice/dead-model", "model_name": "venice/dead-model",
"api_key": "k", "api_key": "k",
"billing_tier": "premium", "billing_tier": "premium",
@ -739,7 +739,7 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch):
}, },
{ {
"id": -2, "id": -2,
"provider": "AZURE_OPENAI", "litellm_provider": "azure",
"model_name": "gpt-5", "model_name": "gpt-5",
"api_key": "k", "api_key": "k",
"billing_tier": "premium", "billing_tier": "premium",
@ -781,7 +781,7 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch):
[ [
{ {
"id": -1, "id": -1,
"provider": "AZURE_OPENAI", "litellm_provider": "azure",
"model_name": "gpt-5", "model_name": "gpt-5",
"api_key": "k", "api_key": "k",
"billing_tier": "premium", "billing_tier": "premium",
@ -791,7 +791,7 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch):
}, },
{ {
"id": -2, "id": -2,
"provider": "AZURE_OPENAI", "litellm_provider": "azure",
"model_name": "gpt-5-pro", "model_name": "gpt-5-pro",
"api_key": "k", "api_key": "k",
"billing_tier": "premium", "billing_tier": "premium",
@ -839,7 +839,7 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
[ [
{ {
"id": -1, "id": -1,
"provider": "OPENROUTER", "litellm_provider": "openrouter",
"model_name": "google/gemma-4-26b-a4b-it:free", "model_name": "google/gemma-4-26b-a4b-it:free",
"api_key": "k", "api_key": "k",
"billing_tier": "free", "billing_tier": "free",
@ -849,7 +849,7 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
}, },
{ {
"id": -2, "id": -2,
"provider": "OPENROUTER", "litellm_provider": "openrouter",
"model_name": "google/gemini-2.5-flash:free", "model_name": "google/gemini-2.5-flash:free",
"api_key": "k", "api_key": "k",
"billing_tier": "free", "billing_tier": "free",
@ -892,7 +892,7 @@ async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch):
[ [
{ {
"id": -1, "id": -1,
"provider": "OPENROUTER", "litellm_provider": "openrouter",
"model_name": "google/gemma-4-26b-a4b-it:free", "model_name": "google/gemma-4-26b-a4b-it:free",
"api_key": "k", "api_key": "k",
"billing_tier": "free", "billing_tier": "free",
@ -937,7 +937,7 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa
[ [
{ {
"id": -1, "id": -1,
"provider": "OPENROUTER", "litellm_provider": "openrouter",
"model_name": "google/gemma-4-26b-a4b-it:free", "model_name": "google/gemma-4-26b-a4b-it:free",
"api_key": "k", "api_key": "k",
"billing_tier": "free", "billing_tier": "free",
@ -947,7 +947,7 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa
}, },
{ {
"id": -2, "id": -2,
"provider": "OPENROUTER", "litellm_provider": "openrouter",
"model_name": "google/gemini-2.5-flash:free", "model_name": "google/gemini-2.5-flash:free",
"api_key": "k", "api_key": "k",
"billing_tier": "free", "billing_tier": "free",

View file

@ -74,7 +74,7 @@ def _thread(*, pinned: int | None = None):
def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict: def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
return { return {
"id": id_, "id": id_,
"provider": "OPENAI", "litellm_provider": "openai",
"model_name": f"vision-{id_}", "model_name": f"vision-{id_}",
"api_key": "k", "api_key": "k",
"billing_tier": tier, "billing_tier": tier,
@ -87,7 +87,7 @@ def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict: def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict:
return { return {
"id": id_, "id": id_,
"provider": "OPENAI", "litellm_provider": "openai",
"model_name": f"text-{id_}", "model_name": f"text-{id_}",
"api_key": "k", "api_key": "k",
"billing_tier": tier, "billing_tier": tier,
@ -261,7 +261,7 @@ async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch):
session = _FakeSession(_thread()) session = _FakeSession(_thread())
cfg_unannotated_vision = { cfg_unannotated_vision = {
"id": -2, "id": -2,
"provider": "OPENAI", "litellm_provider": "openai",
"model_name": "gpt-4o", # known vision model in LiteLLM map "model_name": "gpt-4o", # known vision model in LiteLLM map
"api_key": "k", "api_key": "k",
"billing_tier": "free", "billing_tier": "free",

View file

@ -25,10 +25,10 @@ def _fake_yaml_config(
return { return {
"id": id, "id": id,
"name": f"yaml-{id}", "name": f"yaml-{id}",
"provider": "OPENAI", "litellm_provider": "openai",
"model_name": model_name, "model_name": model_name,
"api_key": "sk-test", "api_key": "sk-test",
"api_base": "", "api_base": "https://api.openai.com/v1",
"billing_tier": billing_tier, "billing_tier": billing_tier,
"rpm": 100, "rpm": 100,
"tpm": 100_000, "tpm": 100_000,
@ -54,10 +54,10 @@ def _fake_openrouter_config(
return { return {
"id": id, "id": id,
"name": f"or-{id}", "name": f"or-{id}",
"provider": "OPENROUTER", "litellm_provider": "openrouter",
"model_name": model_name, "model_name": model_name,
"api_key": "sk-or-test", "api_key": "sk-or-test",
"api_base": "", "api_base": "https://openrouter.ai/api/v1",
"billing_tier": billing_tier, "billing_tier": billing_tier,
"rpm": 20 if billing_tier == "free" else 200, "rpm": 20 if billing_tier == "free" else 200,
"tpm": 100_000 if billing_tier == "free" else 1_000_000, "tpm": 100_000 if billing_tier == "free" else 1_000_000,

View file

@ -25,7 +25,7 @@ def _or_cfg(
) -> dict: ) -> dict:
return { return {
"id": cid, "id": cid,
"provider": "OPENROUTER", "litellm_provider": "openrouter",
"model_name": model_name, "model_name": model_name,
"billing_tier": tier, "billing_tier": tier,
"auto_pin_tier": "B" if tier == "premium" else "C", "auto_pin_tier": "B" if tier == "premium" else "C",
@ -144,7 +144,7 @@ async def test_enrich_health_only_touches_or_provider(monkeypatch):
"""YAML cfgs that aren't OPENROUTER must be skipped entirely.""" """YAML cfgs that aren't OPENROUTER must be skipped entirely."""
yaml_cfg = { yaml_cfg = {
"id": -1, "id": -1,
"provider": "AZURE_OPENAI", "litellm_provider": "azure",
"model_name": "gpt-5", "model_name": "gpt-5",
"billing_tier": "premium", "billing_tier": "premium",
"auto_pin_tier": "A", "auto_pin_tier": "A",
@ -313,7 +313,7 @@ async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch):
"""When the catalogue has no OR cfgs at all, no HTTP calls fire.""" """When the catalogue has no OR cfgs at all, no HTTP calls fire."""
yaml_cfg: dict[str, Any] = { yaml_cfg: dict[str, Any] = {
"id": -1, "id": -1,
"provider": "AZURE_OPENAI", "litellm_provider": "azure",
"model_name": "gpt-5", "model_name": "gpt-5",
"billing_tier": "premium", "billing_tier": "premium",
} }

View file

@ -35,7 +35,7 @@ def test_safety_net_does_not_fire_for_azure_gpt_4o():
it text-only.""" it text-only."""
assert ( assert (
is_known_text_only_chat_model( is_known_text_only_chat_model(
provider="AZURE_OPENAI", litellm_provider="azure",
model_name="my-azure-deployment", model_name="my-azure-deployment",
base_model="gpt-4o", base_model="gpt-4o",
) )
@ -49,7 +49,7 @@ def test_safety_net_does_not_fire_for_unknown_model():
LiteLLM doesn't know about must flow through to the provider.""" LiteLLM doesn't know about must flow through to the provider."""
assert ( assert (
is_known_text_only_chat_model( is_known_text_only_chat_model(
provider="CUSTOM", litellm_provider="custom",
custom_provider="brand_new_proxy", custom_provider="brand_new_proxy",
model_name="brand-new-model-x9", model_name="brand-new-model-x9",
) )
@ -69,7 +69,7 @@ def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch):
assert ( assert (
is_known_text_only_chat_model( is_known_text_only_chat_model(
provider="OPENAI", litellm_provider="openai",
model_name="gpt-4o", model_name="gpt-4o",
) )
is False is False
@ -88,7 +88,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false) monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false)
assert ( assert (
is_known_text_only_chat_model( is_known_text_only_chat_model(
provider="OPENAI", litellm_provider="openai",
model_name="text-only-stub", model_name="text-only-stub",
) )
is True is True
@ -100,7 +100,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
monkeypatch.setattr(pc.litellm, "get_model_info", _info_true) monkeypatch.setattr(pc.litellm, "get_model_info", _info_true)
assert ( assert (
is_known_text_only_chat_model( is_known_text_only_chat_model(
provider="OPENAI", litellm_provider="openai",
model_name="vision-stub", model_name="vision-stub",
) )
is False is False
@ -112,7 +112,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing) monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing)
assert ( assert (
is_known_text_only_chat_model( is_known_text_only_chat_model(
provider="OPENAI", litellm_provider="openai",
model_name="missing-key-stub", model_name="missing-key-stub",
) )
is False is False