mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-12 20:45:20 +02:00
feat(chat): route models by provider capabilities
This commit is contained in:
parent
8f20a32571
commit
c28c4f5785
18 changed files with 429 additions and 319 deletions
|
|
@ -2,9 +2,9 @@
|
|||
LLM configuration utilities for SurfSense agents.
|
||||
|
||||
This module provides functions for loading LLM configurations from:
|
||||
1. Auto mode (ID 0) - Uses LiteLLM Router for load balancing
|
||||
1. Auto mode (ID 0) - Resolved by callers to a concrete model-connection model
|
||||
2. YAML files (global configs with negative IDs)
|
||||
3. Database NewLLMConfig table (user-created configs with positive IDs)
|
||||
3. Database model-connections table (user-created configs with positive IDs)
|
||||
|
||||
It also provides utilities for creating ChatLiteLLM instances and
|
||||
managing prompt configurations.
|
||||
|
|
@ -33,9 +33,7 @@ from app.agents.chat.runtime.prompt_caching import (
|
|||
from app.services.llm_router_service import (
|
||||
AUTO_MODE_ID,
|
||||
ChatLiteLLMRouter,
|
||||
LLMRouterService,
|
||||
_sanitize_content,
|
||||
get_auto_mode_llm,
|
||||
is_auto_mode,
|
||||
)
|
||||
|
||||
|
|
@ -92,14 +90,6 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
|
|||
yield chunk
|
||||
|
||||
|
||||
# Re-exported under the historical name ``PROVIDER_MAP``. Source of truth lives
|
||||
# in provider_capabilities so the YAML loader can resolve prefixes during
|
||||
# app.config init without importing the agent/tools tree.
|
||||
from app.services.provider_capabilities import ( # noqa: E402
|
||||
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
|
||||
)
|
||||
|
||||
|
||||
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
|
||||
"""Attach a ``profile`` dict to ChatLiteLLM with model context metadata."""
|
||||
try:
|
||||
|
|
@ -122,7 +112,8 @@ class AgentConfig:
|
|||
Complete configuration for the SurfSense agent.
|
||||
|
||||
This combines LLM settings with prompt configuration from NewLLMConfig.
|
||||
Supports Auto mode (ID 0) which uses LiteLLM Router for load balancing.
|
||||
Supports Auto mode metadata (ID 0). Runtime callers must resolve Auto to
|
||||
a concrete global or BYOK model before constructing ChatLiteLLM.
|
||||
"""
|
||||
|
||||
# LLM Model Settings
|
||||
|
|
@ -219,7 +210,7 @@ class AgentConfig:
|
|||
# BYOK rows have no curated flag; ask LiteLLM (default-allow on
|
||||
# unknown). The streaming safety net still blocks explicit text-only.
|
||||
supports_image_input=derive_supports_image_input(
|
||||
provider=provider_value,
|
||||
litellm_provider=provider_value.lower(),
|
||||
model_name=config.model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=config.custom_provider,
|
||||
|
|
@ -238,7 +229,7 @@ class AgentConfig:
|
|||
|
||||
system_instructions = yaml_config.get("system_instructions", "")
|
||||
|
||||
provider = yaml_config.get("provider", "").upper()
|
||||
provider = yaml_config.get("litellm_provider", "")
|
||||
model_name = yaml_config.get("model_name", "")
|
||||
custom_provider = yaml_config.get("custom_provider")
|
||||
litellm_params = yaml_config.get("litellm_params") or {}
|
||||
|
|
@ -254,7 +245,7 @@ class AgentConfig:
|
|||
supports_image_input = bool(yaml_config.get("supports_image_input"))
|
||||
else:
|
||||
supports_image_input = derive_supports_image_input(
|
||||
provider=provider,
|
||||
litellm_provider=provider,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=custom_provider,
|
||||
|
|
@ -383,9 +374,6 @@ async def load_agent_config(
|
|||
) -> "AgentConfig | None":
|
||||
"""Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB."""
|
||||
if is_auto_mode(config_id):
|
||||
if not LLMRouterService.is_initialized():
|
||||
print("Error: Auto mode requested but LLM Router not initialized")
|
||||
return None
|
||||
return AgentConfig.from_auto_mode()
|
||||
|
||||
if config_id < 0:
|
||||
|
|
@ -408,9 +396,8 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
|||
if llm_config.get("custom_provider"):
|
||||
model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}"
|
||||
else:
|
||||
provider = llm_config.get("provider", "").upper()
|
||||
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{llm_config['model_name']}"
|
||||
litellm_provider = llm_config.get("litellm_provider", "openai")
|
||||
model_string = f"{litellm_provider}/{llm_config['model_name']}"
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
|
|
@ -433,29 +420,15 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
|||
def create_chat_litellm_from_agent_config(
|
||||
agent_config: AgentConfig,
|
||||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||
"""Create a ChatLiteLLM (or, for Auto mode, a load-balancing router) from config."""
|
||||
"""Create a ChatLiteLLM from an already resolved concrete model config."""
|
||||
if agent_config.is_auto_mode:
|
||||
if not LLMRouterService.is_initialized():
|
||||
print("Error: Auto mode requested but LLM Router not initialized")
|
||||
return None
|
||||
try:
|
||||
router_llm = get_auto_mode_llm()
|
||||
if router_llm is not None:
|
||||
# Universal injection points only: auto-mode fans out across
|
||||
# providers, so provider-specific kwargs have no known target.
|
||||
apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
|
||||
return router_llm
|
||||
except Exception as e:
|
||||
print(f"Error creating ChatLiteLLMRouter: {e}")
|
||||
return None
|
||||
print("Error: Auto mode must be resolved to a concrete model before LLM creation")
|
||||
return None
|
||||
|
||||
if agent_config.custom_provider:
|
||||
model_string = f"{agent_config.custom_provider}/{agent_config.model_name}"
|
||||
else:
|
||||
provider_prefix = PROVIDER_MAP.get(
|
||||
agent_config.provider, agent_config.provider.lower()
|
||||
)
|
||||
model_string = f"{provider_prefix}/{agent_config.model_name}"
|
||||
model_string = f"{agent_config.provider}/{agent_config.model_name}"
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
|
|
|
|||
|
|
@ -132,7 +132,7 @@ async def list_anonymous_models():
|
|||
id=cfg.get("id", 0),
|
||||
name=cfg.get("name", ""),
|
||||
description=cfg.get("description"),
|
||||
provider=cfg.get("provider", ""),
|
||||
provider=cfg.get("litellm_provider", ""),
|
||||
model_name=cfg.get("model_name", ""),
|
||||
billing_tier=cfg.get("billing_tier", "free"),
|
||||
is_premium=cfg.get("billing_tier", "free") == "premium",
|
||||
|
|
@ -161,7 +161,7 @@ async def get_anonymous_model(slug: str):
|
|||
id=cfg.get("id", 0),
|
||||
name=cfg.get("name", ""),
|
||||
description=cfg.get("description"),
|
||||
provider=cfg.get("provider", ""),
|
||||
provider=cfg.get("litellm_provider", ""),
|
||||
model_name=cfg.get("model_name", ""),
|
||||
billing_tier=cfg.get("billing_tier", "free"),
|
||||
is_premium=cfg.get("billing_tier", "free") == "premium",
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ async def get_global_vision_llm_configs(
|
|||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
"description": cfg.get("description"),
|
||||
"provider": cfg.get("provider"),
|
||||
"provider": cfg.get("litellm_provider"),
|
||||
"custom_provider": cfg.get("custom_provider"),
|
||||
"model_name": cfg.get("model_name"),
|
||||
"api_base": cfg.get("api_base") or None,
|
||||
|
|
|
|||
|
|
@ -23,9 +23,10 @@ from uuid import UUID
|
|||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.config import config
|
||||
from app.db import NewChatThread
|
||||
from app.db import Connection, Model, NewChatThread
|
||||
from app.services.quality_score import _QUALITY_TOP_K
|
||||
from app.services.token_quota_service import TokenQuotaService
|
||||
|
||||
|
|
@ -61,11 +62,20 @@ def _is_usable_global_config(cfg: dict) -> bool:
|
|||
return bool(
|
||||
cfg.get("id") is not None
|
||||
and cfg.get("model_name")
|
||||
and cfg.get("provider")
|
||||
and cfg.get("litellm_provider")
|
||||
and cfg.get("api_key")
|
||||
)
|
||||
|
||||
|
||||
def _has_capability(model: dict | Model, capability: str) -> bool:
|
||||
caps = (
|
||||
model.get("capabilities", {})
|
||||
if isinstance(model, dict)
|
||||
else model.capabilities or {}
|
||||
)
|
||||
return bool(caps.get(capability))
|
||||
|
||||
|
||||
def _prune_runtime_cooldowns(now_ts: float | None = None) -> None:
|
||||
now = time.time() if now_ts is None else now_ts
|
||||
stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now]
|
||||
|
|
@ -186,15 +196,19 @@ def _cfg_supports_image_input(cfg: dict) -> bool:
|
|||
else None
|
||||
)
|
||||
return derive_supports_image_input(
|
||||
provider=cfg.get("provider"),
|
||||
litellm_provider=cfg.get("litellm_provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
)
|
||||
|
||||
|
||||
def _global_candidates(*, requires_image_input: bool = False) -> list[dict]:
|
||||
"""Return Auto-eligible global cfgs.
|
||||
def _global_candidates(
|
||||
*,
|
||||
capability: str = "chat",
|
||||
requires_image_input: bool = False,
|
||||
) -> list[dict]:
|
||||
"""Return Auto-eligible global virtual models.
|
||||
|
||||
Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
|
||||
below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
|
||||
|
|
@ -205,17 +219,135 @@ def _global_candidates(*, requires_image_input: bool = False) -> list[dict]:
|
|||
filters out configs whose ``supports_image_input`` resolves to False
|
||||
so a text-only deployment can't be pinned for an image request.
|
||||
"""
|
||||
candidates = [
|
||||
cfg
|
||||
connection_by_id = {
|
||||
int(conn.get("id")): conn
|
||||
for conn in config.GLOBAL_CONNECTIONS
|
||||
if conn.get("id") is not None
|
||||
}
|
||||
config_by_model_name = {
|
||||
cfg.get("model_name"): cfg
|
||||
for cfg in config.GLOBAL_LLM_CONFIGS
|
||||
if _is_usable_global_config(cfg)
|
||||
and not cfg.get("health_gated")
|
||||
and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
|
||||
and (not requires_image_input or _cfg_supports_image_input(cfg))
|
||||
]
|
||||
}
|
||||
candidates: list[dict] = []
|
||||
for model in config.GLOBAL_MODELS:
|
||||
model_id = int(model.get("id", 0))
|
||||
if model_id >= 0 or _is_runtime_cooled_down(model_id):
|
||||
continue
|
||||
if not _has_capability(model, capability):
|
||||
continue
|
||||
cfg = config_by_model_name.get(model.get("model_id")) or {}
|
||||
if cfg.get("health_gated"):
|
||||
continue
|
||||
if requires_image_input and not _has_capability(model, "vision"):
|
||||
continue
|
||||
if requires_image_input and cfg and not _cfg_supports_image_input(cfg):
|
||||
continue
|
||||
connection = connection_by_id.get(int(model.get("connection_id", 0)))
|
||||
if not connection:
|
||||
continue
|
||||
catalog = model.get("catalog") or {}
|
||||
candidates.append(
|
||||
{
|
||||
"id": model_id,
|
||||
"model_id": model.get("model_id"),
|
||||
"source": "global",
|
||||
"connection": connection,
|
||||
"capabilities": model.get("capabilities") or {},
|
||||
"billing_tier": model.get("billing_tier", "free"),
|
||||
"litellm_provider": connection.get("litellm_provider"),
|
||||
"model_name": model.get("model_id"),
|
||||
"auto_pin_tier": catalog.get("auto_pin_tier")
|
||||
or cfg.get("auto_pin_tier")
|
||||
or "A",
|
||||
"quality_score": catalog.get("quality_score")
|
||||
or cfg.get("quality_score")
|
||||
or cfg.get("quality_score_static")
|
||||
or 50,
|
||||
}
|
||||
)
|
||||
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
|
||||
|
||||
|
||||
async def _db_candidates(
    session: AsyncSession,
    *,
    search_space_id: int,
    user_id: str | UUID | None,
    capability: str,
    requires_image_input: bool = False,
) -> list[dict]:
    """Return Auto-eligible candidates built from enabled DB models.

    A model qualifies when both it and its connection are enabled, the
    connection is either unscoped or scoped to ``search_space_id``, the
    connection is either unowned or owned by ``user_id``, the model has the
    requested ``capability`` (plus ``"vision"`` when ``requires_image_input``
    is set), and the model id is not currently runtime-cooled-down.

    Args:
        session: Async SQLAlchemy session used to run the query.
        search_space_id: Search space the request belongs to.
        user_id: Requesting user; converted with ``_to_uuid``. When it
            resolves to ``None`` only unowned connections are considered.
        capability: Capability flag every candidate must have.
        requires_image_input: Additionally require the ``"vision"`` flag.

    Returns:
        Candidate dicts sorted by ascending ``id``.
    """
    parsed_user_id = _to_uuid(user_id)

    # Apply tenant scoping in SQL rather than after the fetch, so rows that
    # belong to other search spaces or users are never loaded into memory.
    space_scope = Connection.search_space_id.is_(None) | (
        Connection.search_space_id == search_space_id
    )
    owner_scope = Connection.user_id.is_(None)
    if parsed_user_id is not None:
        owner_scope = owner_scope | (Connection.user_id == parsed_user_id)

    stmt = (
        select(Model)
        .options(selectinload(Model.connection))
        .join(Connection, Model.connection_id == Connection.id)
        .where(
            Model.enabled.is_(True),
            Connection.enabled.is_(True),
            space_scope,
            owner_scope,
        )
    )
    result = await session.execute(stmt)

    candidates: list[dict] = []
    for model in result.scalars().all():
        conn = model.connection
        if not conn:
            continue
        if not _has_capability(model, capability):
            continue
        if requires_image_input and not _has_capability(model, "vision"):
            continue
        model_id = int(model.id)
        if _is_runtime_cooled_down(model_id):
            continue
        catalog = model.catalog or {}
        candidates.append(
            {
                "id": model_id,
                "model_id": model.model_id,
                "source": "db",
                "connection": conn,
                "capabilities": model.capabilities or {},
                "billing_tier": "byok",
                "litellm_provider": conn.litellm_provider,
                "model_name": model.model_id,
                "auto_pin_tier": catalog.get("auto_pin_tier") or "BYOK",
                "quality_score": catalog.get("quality_score") or 75,
            }
        )
    return sorted(candidates, key=lambda c: int(c.get("id", 0)))
|
||||
|
||||
|
||||
async def auto_model_candidates(
    session: AsyncSession,
    *,
    search_space_id: int,
    user_id: str | UUID | None,
    capability: str,
    requires_image_input: bool = False,
    exclude_model_ids: set[int] | None = None,
) -> list[dict]:
    """Collect global and DB candidates for Auto mode, minus excluded ids.

    Global candidates come first in the result, followed by DB candidates;
    any entry whose integer ``id`` appears in ``exclude_model_ids`` is
    dropped.
    """
    skip_ids: set[int] = set()
    if exclude_model_ids:
        skip_ids = {int(mid) for mid in exclude_model_ids}

    byok_pool = await _db_candidates(
        session,
        search_space_id=search_space_id,
        user_id=user_id,
        capability=capability,
        requires_image_input=requires_image_input,
    )
    global_pool = _global_candidates(
        capability=capability,
        requires_image_input=requires_image_input,
    )

    kept: list[dict] = []
    for candidate in [*global_pool, *byok_pool]:
        if int(candidate.get("id", 0)) in skip_ids:
            continue
        kept.append(candidate)
    return kept
|
||||
|
||||
|
||||
def _tier_of(cfg: dict) -> str:
|
||||
return str(cfg.get("billing_tier", "free")).lower()
|
||||
|
||||
|
|
@ -223,8 +355,9 @@ def _tier_of(cfg: dict) -> str:
|
|||
def _is_preferred_premium_auto_config(cfg: dict) -> bool:
|
||||
"""Return True for the operator-preferred premium Auto model."""
|
||||
return (
|
||||
_tier_of(cfg) == "premium"
|
||||
and str(cfg.get("provider", "")).upper() == "AZURE_OPENAI"
|
||||
cfg.get("source") == "global"
|
||||
and _tier_of(cfg) == "premium"
|
||||
and str(cfg.get("litellm_provider", "")).lower() == "azure"
|
||||
and str(cfg.get("model_name", "")).lower() == "gpt-5.4"
|
||||
)
|
||||
|
||||
|
|
@ -251,6 +384,11 @@ def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
|
|||
return top_k[idx], len(top_k)
|
||||
|
||||
|
||||
def choose_auto_model_candidate(candidates: list[dict], seed_id: int) -> dict:
    """Return the candidate that ``_select_pin`` picks for ``seed_id``.

    ``_select_pin`` returns a ``(selected, pool_size)`` pair; only the
    selected candidate is passed on to the caller.
    """
    chosen, _pool_size = _select_pin(candidates, seed_id)
    return chosen
|
||||
|
||||
|
||||
def _to_uuid(user_id: str | UUID | None) -> UUID | None:
|
||||
if user_id is None:
|
||||
return None
|
||||
|
|
@ -326,20 +464,23 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
)
|
||||
|
||||
excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
|
||||
candidates = [
|
||||
c
|
||||
for c in _global_candidates(requires_image_input=requires_image_input)
|
||||
if int(c.get("id", 0)) not in excluded_ids
|
||||
]
|
||||
candidates = await auto_model_candidates(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
capability="chat",
|
||||
requires_image_input=requires_image_input,
|
||||
exclude_model_ids=excluded_ids,
|
||||
)
|
||||
if not candidates:
|
||||
if requires_image_input:
|
||||
# Distinguish the "no vision-capable cfg" case from generic
|
||||
# "no usable cfg" so the streaming task can map this to the
|
||||
# MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error.
|
||||
raise ValueError(
|
||||
"No vision-capable global LLM configs are available for Auto mode"
|
||||
"No vision-capable LLM models are available for Auto mode"
|
||||
)
|
||||
raise ValueError("No usable global LLM configs are available for Auto mode")
|
||||
raise ValueError("No usable LLM models are available for Auto mode")
|
||||
candidate_by_id = {int(c["id"]): c for c in candidates}
|
||||
|
||||
# Reuse an existing valid pin without re-checking current quota (no silent
|
||||
|
|
@ -379,24 +520,13 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
# log that explicitly so operators can correlate the re-pin with
|
||||
# the user's image attachment instead of suspecting a cooldown.
|
||||
if requires_image_input:
|
||||
try:
|
||||
pinned_global = next(
|
||||
c
|
||||
for c in config.GLOBAL_LLM_CONFIGS
|
||||
if int(c.get("id", 0)) == int(pinned_id)
|
||||
)
|
||||
except StopIteration:
|
||||
pinned_global = None
|
||||
if pinned_global is not None and not _cfg_supports_image_input(
|
||||
pinned_global
|
||||
):
|
||||
logger.info(
|
||||
"auto_pin_repinned_for_image thread_id=%s search_space_id=%s "
|
||||
"previous_config_id=%s",
|
||||
thread_id,
|
||||
search_space_id,
|
||||
pinned_id,
|
||||
)
|
||||
logger.info(
|
||||
"auto_pin_repinned_for_image thread_id=%s search_space_id=%s "
|
||||
"previous_config_id=%s",
|
||||
thread_id,
|
||||
search_space_id,
|
||||
pinned_id,
|
||||
)
|
||||
logger.info(
|
||||
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
|
||||
thread_id,
|
||||
|
|
|
|||
|
|
@ -30,11 +30,7 @@ from litellm.exceptions import (
|
|||
)
|
||||
from pydantic import Field
|
||||
|
||||
from app.services.model_resolver import (
|
||||
NATIVE_PROVIDER_PREFIX,
|
||||
native_connection_from_config,
|
||||
to_litellm,
|
||||
)
|
||||
from app.services.model_resolver import native_connection_from_config, to_litellm
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
litellm.json_logs = False
|
||||
|
|
@ -101,10 +97,6 @@ def _sanitize_content(content: Any) -> Any:
|
|||
# Special ID for Auto mode - uses router for load balancing
|
||||
AUTO_MODE_ID = 0
|
||||
|
||||
# Historical export kept for callers that still import PROVIDER_MAP.
|
||||
PROVIDER_MAP = NATIVE_PROVIDER_PREFIX
|
||||
|
||||
|
||||
class LLMRouterService:
|
||||
"""
|
||||
Singleton service for managing LiteLLM Router.
|
||||
|
|
|
|||
|
|
@ -10,13 +10,11 @@ from sqlalchemy.orm import selectinload
|
|||
|
||||
from app.config import config
|
||||
from app.db import Model, SearchSpace
|
||||
from app.services.llm_router_service import (
|
||||
AUTO_MODE_ID,
|
||||
ChatLiteLLMRouter,
|
||||
LLMRouterService,
|
||||
get_auto_mode_llm,
|
||||
is_auto_mode,
|
||||
from app.services.auto_model_pin_service import (
|
||||
auto_model_candidates,
|
||||
choose_auto_model_candidate,
|
||||
)
|
||||
from app.services.llm_router_service import AUTO_MODE_ID, ChatLiteLLMRouter, is_auto_mode
|
||||
from app.services.model_resolver import native_connection_from_config, to_litellm
|
||||
from app.services.token_tracking_service import token_tracker
|
||||
|
||||
|
|
@ -78,7 +76,7 @@ def _legacy_config_connection(
|
|||
api_version: str | None = None,
|
||||
) -> tuple[str, dict]:
|
||||
cfg = {
|
||||
"provider": provider,
|
||||
"litellm_provider": provider.lower(),
|
||||
"model_name": model_name,
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
|
|
@ -325,23 +323,21 @@ async def get_search_space_llm_instance(
|
|||
logger.error(f"No {role} LLM configured for search space {search_space_id}")
|
||||
return None
|
||||
|
||||
# Check for Auto mode (ID 0) - use router for load balancing
|
||||
# Auto mode resolves to one concrete global or BYOK model from the
|
||||
# unified model-connections catalog.
|
||||
if is_auto_mode(llm_config_id):
|
||||
if not LLMRouterService.is_initialized():
|
||||
logger.error(
|
||||
"Auto mode requested but LLM Router not initialized. "
|
||||
"Ensure global_llm_config.yaml exists with valid configs."
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
logger.debug(
|
||||
f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}"
|
||||
)
|
||||
return get_auto_mode_llm(streaming=not disable_streaming)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
|
||||
candidates = await auto_model_candidates(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
user_id=search_space.user_id,
|
||||
capability="chat",
|
||||
)
|
||||
if not candidates:
|
||||
logger.error("No chat-capable models available for Auto mode")
|
||||
return None
|
||||
llm_config_id = int(
|
||||
choose_auto_model_candidate(candidates, search_space_id)["id"]
|
||||
)
|
||||
|
||||
# Check if this is a global virtual model (negative ID)
|
||||
if llm_config_id < 0:
|
||||
|
|
@ -414,7 +410,7 @@ async def get_vision_llm(
|
|||
"""Get the search space's vision LLM instance for screenshot analysis.
|
||||
|
||||
Resolves from the new connection/model role bindings:
|
||||
- Auto mode (ID 0): VisionLLMRouterService
|
||||
- Auto mode (ID 0): unified global/BYOK model candidate selection
|
||||
- Global (negative ID): virtual GLOBAL models from YAML
|
||||
- DB (positive ID): Model + Connection tables
|
||||
|
||||
|
|
@ -424,10 +420,7 @@ async def get_vision_llm(
|
|||
unwrapped — they don't consume premium credit (issue M).
|
||||
"""
|
||||
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
|
||||
from app.services.vision_llm_router_service import (
|
||||
VisionLLMRouterService,
|
||||
is_vision_auto_mode,
|
||||
)
|
||||
from app.services.vision_llm_router_service import is_vision_auto_mode
|
||||
|
||||
try:
|
||||
result = await session.execute(
|
||||
|
|
@ -476,26 +469,16 @@ async def get_vision_llm(
|
|||
return None
|
||||
|
||||
if is_vision_auto_mode(config_id):
|
||||
if not VisionLLMRouterService.is_initialized():
|
||||
logger.error(
|
||||
"Vision Auto mode requested but Vision LLM Router not initialized"
|
||||
)
|
||||
return None
|
||||
try:
|
||||
# Auto mode is currently treated as free at the wrapper
|
||||
# level — the underlying router can dispatch to either
|
||||
# premium or free YAML configs but routing decisions are
|
||||
# opaque. If/when we want to bill Auto-routed vision
|
||||
# calls we'd need to thread the resolved deployment's
|
||||
# billing_tier back from the router. For now we keep
|
||||
# parity with chat Auto, which also doesn't pre-classify.
|
||||
return ChatLiteLLMRouter(
|
||||
router=VisionLLMRouterService.get_router(),
|
||||
streaming=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create vision ChatLiteLLMRouter: {e}")
|
||||
candidates = await auto_model_candidates(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
user_id=owner_user_id,
|
||||
capability="vision",
|
||||
)
|
||||
if not candidates:
|
||||
logger.error("No vision-capable models available for Auto mode")
|
||||
return None
|
||||
config_id = int(choose_auto_model_candidate(candidates, search_space_id)["id"])
|
||||
|
||||
if config_id < 0:
|
||||
global_model = get_global_model(config_id)
|
||||
|
|
|
|||
|
|
@ -154,19 +154,19 @@ def _process_models(raw_models: list[dict]) -> list[dict]:
|
|||
}
|
||||
)
|
||||
|
||||
# 2) Emit for the native provider when we have a mapping
|
||||
native_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug)
|
||||
if native_provider:
|
||||
# 2) Emit for the direct provider when we have a mapping
|
||||
direct_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug)
|
||||
if direct_provider:
|
||||
# Google's Gemini API only serves gemini-* models.
|
||||
# Open-source models like gemma-* are NOT available through it.
|
||||
if native_provider == "GOOGLE" and not model_name.startswith("gemini-"):
|
||||
if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"):
|
||||
continue
|
||||
|
||||
processed.append(
|
||||
{
|
||||
"value": model_name,
|
||||
"label": name,
|
||||
"provider": native_provider,
|
||||
"provider": direct_provider,
|
||||
"context_window": context_window,
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,53 +9,12 @@ from __future__ import annotations
|
|||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.db import Connection
|
||||
|
||||
PROTOCOL_OLLAMA = "OLLAMA"
|
||||
PROTOCOL_OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE"
|
||||
PROTOCOL_NATIVE = "NATIVE"
|
||||
|
||||
NATIVE_PROVIDER_PREFIX: dict[str, str] = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"AZURE": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
"COMETAPI": "cometapi",
|
||||
"XAI": "xai",
|
||||
"BEDROCK": "bedrock",
|
||||
"AWS_BEDROCK": "bedrock",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"DEEPSEEK": "openai",
|
||||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai",
|
||||
"GITHUB_MODELS": "github",
|
||||
"REPLICATE": "replicate",
|
||||
"PERPLEXITY": "perplexity",
|
||||
"ANYSCALE": "anyscale",
|
||||
"DEEPINFRA": "deepinfra",
|
||||
"CEREBRAS": "cerebras",
|
||||
"SAMBANOVA": "sambanova",
|
||||
"AI21": "ai21",
|
||||
"CLOUDFLARE": "cloudflare",
|
||||
"DATABRICKS": "databricks",
|
||||
"HUGGINGFACE": "huggingface",
|
||||
"MINIMAX": "openai",
|
||||
"RECRAFT": "recraft",
|
||||
"XINFERENCE": "xinference",
|
||||
"NSCALE": "nscale",
|
||||
"CUSTOM": "custom",
|
||||
}
|
||||
PROTOCOL_ANTHROPIC = "ANTHROPIC"
|
||||
|
||||
|
||||
def ensure_v1(base_url: str | None) -> str | None:
|
||||
|
|
@ -77,6 +36,23 @@ def _protocol_value(protocol: Any) -> str:
|
|||
return getattr(protocol, "value", str(protocol))
|
||||
|
||||
|
||||
def default_litellm_provider(protocol: Any) -> str:
    """Map a connection protocol to the LiteLLM provider prefix used by default.

    Ollama maps to ``"ollama_chat"`` and Anthropic to ``"anthropic"``; the
    OpenAI-compatible protocol and any unrecognised value map to ``"openai"``.
    """
    name = _protocol_value(protocol)
    if name == PROTOCOL_OLLAMA:
        return "ollama_chat"
    if name == PROTOCOL_ANTHROPIC:
        return "anthropic"
    # PROTOCOL_OPENAI_COMPATIBLE and unknown protocols share the same default.
    return "openai"
|
||||
|
||||
|
||||
def _execution_api_base(protocol: str, base_url: str | None) -> str | None:
|
||||
del protocol
|
||||
if not base_url:
|
||||
return None
|
||||
return base_url.rstrip("/")
|
||||
|
||||
|
||||
def to_litellm(
|
||||
conn: Connection | Mapping[str, Any],
|
||||
model_id: str,
|
||||
|
|
@ -85,38 +61,19 @@ def to_litellm(
|
|||
protocol = _protocol_value(_conn_value(conn, "protocol"))
|
||||
base_url = _conn_value(conn, "base_url")
|
||||
api_key = _conn_value(conn, "api_key")
|
||||
native_provider = _conn_value(conn, "native_provider")
|
||||
litellm_provider = (
|
||||
_conn_value(conn, "litellm_provider") or default_litellm_provider(protocol)
|
||||
)
|
||||
extra = _conn_value(conn, "extra") or {}
|
||||
|
||||
kwargs: dict[str, Any] = {}
|
||||
if api_key:
|
||||
kwargs["api_key"] = api_key
|
||||
|
||||
if protocol == PROTOCOL_OLLAMA:
|
||||
model_string = f"ollama_chat/{model_id}"
|
||||
if base_url:
|
||||
kwargs["api_base"] = base_url.rstrip("/")
|
||||
elif protocol == PROTOCOL_OPENAI_COMPATIBLE:
|
||||
model_string = f"openai/{model_id}"
|
||||
api_base = ensure_v1(base_url)
|
||||
if api_base:
|
||||
kwargs["api_base"] = api_base
|
||||
else:
|
||||
provider_key = (native_provider or "").upper()
|
||||
prefix = NATIVE_PROVIDER_PREFIX.get(provider_key, provider_key.lower())
|
||||
if prefix == "custom":
|
||||
custom_provider = extra.get("custom_provider") or native_provider
|
||||
model_string = f"{custom_provider}/{model_id}" if custom_provider else model_id
|
||||
else:
|
||||
model_string = f"{prefix}/{model_id}"
|
||||
|
||||
api_base = resolve_api_base(
|
||||
provider=provider_key,
|
||||
provider_prefix=prefix,
|
||||
config_api_base=base_url,
|
||||
)
|
||||
if api_base:
|
||||
kwargs["api_base"] = api_base
|
||||
model_string = f"{litellm_provider}/{model_id}" if litellm_provider else model_id
|
||||
api_base = _execution_api_base(protocol, base_url)
|
||||
if api_base:
|
||||
kwargs["api_base"] = api_base
|
||||
|
||||
if api_version := extra.get("api_version"):
|
||||
kwargs["api_version"] = api_version
|
||||
|
|
@ -126,18 +83,21 @@ def to_litellm(
|
|||
|
||||
|
||||
def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]:
|
||||
"""Build an in-memory NATIVE connection mapping from a legacy/global config."""
|
||||
provider = str(config.get("provider") or config.get("custom_provider") or "CUSTOM")
|
||||
"""Build an in-memory connection mapping from a global config."""
|
||||
protocol = str(config.get("protocol") or PROTOCOL_OPENAI_COMPATIBLE)
|
||||
litellm_provider = str(
|
||||
config.get("litellm_provider")
|
||||
or config.get("custom_provider")
|
||||
or default_litellm_provider(protocol)
|
||||
)
|
||||
extra: dict[str, Any] = {
|
||||
"litellm_params": config.get("litellm_params") or {},
|
||||
}
|
||||
if config.get("api_version"):
|
||||
extra["api_version"] = config.get("api_version")
|
||||
if config.get("custom_provider"):
|
||||
extra["custom_provider"] = config.get("custom_provider")
|
||||
return {
|
||||
"protocol": PROTOCOL_NATIVE,
|
||||
"native_provider": provider,
|
||||
"protocol": protocol,
|
||||
"litellm_provider": litellm_provider,
|
||||
"base_url": config.get("api_base") or None,
|
||||
"api_key": config.get("api_key") or None,
|
||||
"extra": extra,
|
||||
|
|
@ -145,7 +105,7 @@ def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]:
|
|||
|
||||
|
||||
__all__ = [
|
||||
"NATIVE_PROVIDER_PREFIX",
|
||||
"default_litellm_provider",
|
||||
"ensure_v1",
|
||||
"native_connection_from_config",
|
||||
"to_litellm",
|
||||
|
|
|
|||
|
|
@ -3,19 +3,12 @@ from typing import Any
|
|||
|
||||
from litellm import Router
|
||||
|
||||
from app.services.model_resolver import (
|
||||
NATIVE_PROVIDER_PREFIX,
|
||||
native_connection_from_config,
|
||||
to_litellm,
|
||||
)
|
||||
from app.services.model_resolver import native_connection_from_config, to_litellm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VISION_AUTO_MODE_ID = 0
|
||||
|
||||
VISION_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX
|
||||
|
||||
|
||||
class VisionLLMRouterService:
|
||||
_instance = None
|
||||
_router: Router | None = None
|
||||
|
|
@ -141,12 +134,11 @@ def is_vision_auto_mode(config_id: int | None) -> bool:
|
|||
|
||||
|
||||
def build_vision_model_string(
|
||||
provider: str, model_name: str, custom_provider: str | None
|
||||
litellm_provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
if custom_provider:
|
||||
return f"{custom_provider}/{model_name}"
|
||||
prefix = VISION_PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
return f"{prefix}/{model_name}"
|
||||
return f"{litellm_provider}/{model_name}"
|
||||
|
||||
|
||||
def get_global_vision_llm_config(config_id: int) -> dict | None:
|
||||
|
|
|
|||
|
|
@ -97,16 +97,16 @@ def _process_vision_models(raw_models: list[dict]) -> list[dict]:
|
|||
}
|
||||
)
|
||||
|
||||
native_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug)
|
||||
if native_provider:
|
||||
if native_provider == "GOOGLE" and not model_name.startswith("gemini-"):
|
||||
direct_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug)
|
||||
if direct_provider:
|
||||
if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"):
|
||||
continue
|
||||
|
||||
processed.append(
|
||||
{
|
||||
"value": model_name,
|
||||
"label": name,
|
||||
"provider": native_provider,
|
||||
"provider": direct_provider,
|
||||
"context_window": context_window,
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ def check_image_input_capability(
|
|||
else None
|
||||
)
|
||||
if not is_known_text_only_chat_model(
|
||||
provider=agent_config.provider,
|
||||
litellm_provider=agent_config.provider,
|
||||
model_name=agent_config.model_name,
|
||||
base_model=agent_base_model,
|
||||
custom_provider=agent_config.custom_provider,
|
||||
|
|
|
|||
|
|
@ -80,7 +80,6 @@ async def _generate_title(
|
|||
from litellm import acompletion
|
||||
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.token_tracking_service import _turn_accumulator
|
||||
|
||||
# Excludes this turn's own assistant row (pre-written by
|
||||
|
|
@ -125,26 +124,12 @@ async def _generate_title(
|
|||
router = LLMRouterService.get_router()
|
||||
response = await router.acompletion(model="auto", messages=messages)
|
||||
else:
|
||||
# Apply the same ``api_base`` cascade chat / vision / image-gen
|
||||
# call sites use so we never inherit ``litellm.api_base``
|
||||
# (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat
|
||||
# config itself ships an empty ``api_base``. Without this the
|
||||
# title-gen on an OpenRouter chat config would 404 against the
|
||||
# inherited Azure endpoint — see ``provider_api_base`` for the
|
||||
# same bug repro on the image-gen / vision paths.
|
||||
raw_model = getattr(llm, "model", "") or ""
|
||||
provider_prefix = raw_model.split("/", 1)[0] if "/" in raw_model else None
|
||||
provider_value = agent_config.provider if agent_config is not None else None
|
||||
title_api_base = resolve_api_base(
|
||||
provider=provider_value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=getattr(llm, "api_base", None),
|
||||
)
|
||||
response = await acompletion(
|
||||
model=raw_model,
|
||||
messages=messages,
|
||||
api_key=getattr(llm, "api_key", None),
|
||||
api_base=title_api_base,
|
||||
api_base=getattr(llm, "api_base", None),
|
||||
)
|
||||
|
||||
usage_info = None
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
"""Load an LLM + AgentConfig bundle for a given config id.
|
||||
|
||||
Handles both code paths uniformly:
|
||||
- ``config_id >= 0`` → database-backed ``NewLLMConfig`` row (per-user/per-space).
|
||||
- ``config_id < 0`` → YAML-defined global LLM config (built-in defaults).
|
||||
- ``config_id > 0`` → database-backed model-connection ``Model`` row.
|
||||
- ``config_id < 0`` → virtual global model materialized from YAML/OpenRouter.
|
||||
|
||||
Returns ``(llm, agent_config, error_message)``; on success ``error_message`` is
|
||||
``None``. The caller emits the friendly SSE error frame.
|
||||
|
|
@ -12,15 +12,72 @@ from __future__ import annotations
|
|||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.agents.chat.runtime.llm_config import (
|
||||
AgentConfig,
|
||||
create_chat_litellm_from_agent_config,
|
||||
create_chat_litellm_from_config,
|
||||
load_agent_config,
|
||||
load_global_llm_config_by_id,
|
||||
SanitizedChatLiteLLM,
|
||||
)
|
||||
from app.config import config
|
||||
from app.db import Model, SearchSpace
|
||||
from app.services.model_resolver import to_litellm
|
||||
|
||||
|
||||
def _agent_config_from_resolved(
    *,
    config_id: int,
    config_name: str | None,
    provider: str,
    model_name: str,
    api_key: str | None,
    api_base: str | None,
    litellm_params: dict | None,
    supports_image_input: bool,
    billing_tier: str = "free",
) -> AgentConfig:
    """Assemble an ``AgentConfig`` for a concrete, already-resolved model.

    All arguments are keyword-only and are forwarded to ``AgentConfig``
    unchanged, with these exceptions:

    - a falsy ``api_key`` (``None`` or ``""``) is passed as ``""``;
    - ``custom_provider`` is always ``None``;
    - ``is_auto_mode`` is always ``False``;
    - ``is_premium`` is ``True`` only when ``billing_tier`` equals
      ``"premium"`` exactly.
    """
    effective_api_key = api_key or ""
    premium = billing_tier == "premium"

    return AgentConfig(
        # Identity of the selected config.
        config_id=config_id,
        config_name=config_name,
        # Model / connection routing.
        provider=provider,
        model_name=model_name,
        custom_provider=None,
        api_key=effective_api_key,
        api_base=api_base,
        litellm_params=litellm_params,
        # Mode, billing and capability flags.
        is_auto_mode=False,
        billing_tier=billing_tier,
        is_premium=premium,
        supports_image_input=supports_image_input,
    )
|
||||
|
||||
|
||||
async def _load_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace | None:
    """Return the ``SearchSpace`` whose id is ``search_space_id``, or ``None`` if no row matches."""
    query = select(SearchSpace).where(SearchSpace.id == search_space_id)
    rows = await session.execute(query)
    return rows.scalars().first()
|
||||
|
||||
|
||||
async def _load_db_model(
    session: AsyncSession,
    *,
    model_id: int,
    search_space: SearchSpace,
) -> Model | None:
    """Load an enabled ``Model`` row that ``search_space`` is allowed to use.

    The model's ``connection`` relationship is loaded together with the row.
    ``None`` is returned when:

    - no enabled model with ``model_id`` exists;
    - the model has no connection, or the connection is not enabled;
    - the connection has a ``search_space_id`` that differs from
      ``search_space.id``;
    - the connection has a ``user_id`` that differs from
      ``search_space.user_id``.

    A connection whose ``search_space_id`` / ``user_id`` is ``None`` is not
    restricted by the corresponding check.
    """
    query = (
        select(Model)
        .options(selectinload(Model.connection))
        .where(Model.id == model_id, Model.enabled.is_(True))
    )
    found = (await session.execute(query)).scalars().first()

    # Availability: the row, its connection, and the connection's enabled flag.
    if not found:
        return None
    if not found.connection:
        return None
    if not found.connection.enabled:
        return None

    # Scoping: a set search_space_id / user_id must match the caller's space.
    connection = found.connection
    bound_to_other_space = (
        connection.search_space_id is not None
        and connection.search_space_id != search_space.id
    )
    if bound_to_other_space:
        return None
    bound_to_other_user = (
        connection.user_id is not None
        and connection.user_id != search_space.user_id
    )
    if bound_to_other_user:
        return None

    return found
|
||||
|
||||
|
||||
async def load_llm_bundle(
|
||||
|
|
@ -29,29 +86,67 @@ async def load_llm_bundle(
|
|||
config_id: int,
|
||||
search_space_id: int,
|
||||
) -> tuple[Any, AgentConfig | None, str | None]:
|
||||
if config_id >= 0:
|
||||
loaded_agent_config = await load_agent_config(
|
||||
session=session,
|
||||
config_id=config_id,
|
||||
search_space_id=search_space_id,
|
||||
search_space = await _load_search_space(session, search_space_id)
|
||||
if not search_space:
|
||||
return None, None, f"Search space {search_space_id} not found"
|
||||
|
||||
if config_id > 0:
|
||||
model = await _load_db_model(
|
||||
session,
|
||||
model_id=config_id,
|
||||
search_space=search_space,
|
||||
)
|
||||
if not loaded_agent_config:
|
||||
if not model or not (model.capabilities or {}).get("chat"):
|
||||
return (
|
||||
None,
|
||||
None,
|
||||
f"Failed to load NewLLMConfig with id {config_id}",
|
||||
f"Failed to load chat model with id {config_id}",
|
||||
)
|
||||
model_string, litellm_kwargs = to_litellm(model.connection, model.model_id)
|
||||
agent_config = _agent_config_from_resolved(
|
||||
config_id=config_id,
|
||||
config_name=model.display_name or model.model_id,
|
||||
provider=model.connection.litellm_provider or "",
|
||||
model_name=model.model_id,
|
||||
api_key=model.connection.api_key,
|
||||
api_base=model.connection.base_url,
|
||||
litellm_params=(model.connection.extra or {}).get("litellm_params"),
|
||||
supports_image_input=bool((model.capabilities or {}).get("vision")),
|
||||
billing_tier="free",
|
||||
)
|
||||
return (
|
||||
create_chat_litellm_from_agent_config(loaded_agent_config),
|
||||
loaded_agent_config,
|
||||
SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
|
||||
agent_config,
|
||||
None,
|
||||
)
|
||||
|
||||
loaded_llm_config = load_global_llm_config_by_id(config_id)
|
||||
if not loaded_llm_config:
|
||||
return None, None, f"Failed to load LLM config with id {config_id}"
|
||||
return (
|
||||
create_chat_litellm_from_config(loaded_llm_config),
|
||||
AgentConfig.from_yaml_config(loaded_llm_config),
|
||||
global_model = next((m for m in config.GLOBAL_MODELS if m.get("id") == config_id), None)
|
||||
if not global_model or not (global_model.get("capabilities") or {}).get("chat"):
|
||||
return None, None, f"Failed to load global chat model with id {config_id}"
|
||||
global_connection = next(
|
||||
(
|
||||
c
|
||||
for c in config.GLOBAL_CONNECTIONS
|
||||
if c.get("id") == global_model.get("connection_id")
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not global_connection:
|
||||
return None, None, f"Failed to load global connection for model {config_id}"
|
||||
model_string, litellm_kwargs = to_litellm(global_connection, global_model["model_id"])
|
||||
agent_config = _agent_config_from_resolved(
|
||||
config_id=config_id,
|
||||
config_name=global_model.get("display_name") or global_model.get("model_id"),
|
||||
provider=global_connection.get("litellm_provider") or "",
|
||||
model_name=global_model["model_id"],
|
||||
api_key=global_connection.get("api_key"),
|
||||
api_base=global_connection.get("base_url"),
|
||||
litellm_params=(global_connection.get("extra") or {}).get("litellm_params"),
|
||||
supports_image_input=bool((global_model.get("capabilities") or {}).get("vision")),
|
||||
billing_tier=str(global_model.get("billing_tier", "free")).lower(),
|
||||
)
|
||||
return (
|
||||
SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
|
||||
agent_config,
|
||||
None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -75,10 +75,10 @@ async def test_auto_first_turn_pins_one_model(monkeypatch):
|
|||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
|
||||
{"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"},
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-prem",
|
||||
"api_key": "k2",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -117,7 +117,7 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
|
|||
[
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-free",
|
||||
"api_key": "k1",
|
||||
"billing_tier": "free",
|
||||
|
|
@ -125,7 +125,7 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
|
|||
},
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-prem",
|
||||
"api_key": "k2",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -164,7 +164,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
|
|||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5.1",
|
||||
"api_key": "k1",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -173,7 +173,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
|
|||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5.4",
|
||||
"api_key": "k2",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -182,7 +182,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
|
|||
},
|
||||
{
|
||||
"id": -3,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "openai/gpt-5.4",
|
||||
"api_key": "k3",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -222,7 +222,7 @@ async def test_next_turn_reuses_existing_pin(monkeypatch):
|
|||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-prem",
|
||||
"api_key": "k2",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -263,7 +263,7 @@ async def test_premium_eligible_auto_can_pin_premium(monkeypatch):
|
|||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-prem",
|
||||
"api_key": "k2",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -301,14 +301,14 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch):
|
|||
[
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-free",
|
||||
"api_key": "k1",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-prem",
|
||||
"api_key": "k2",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -346,14 +346,14 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch):
|
|||
[
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-free",
|
||||
"api_key": "k1",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-prem",
|
||||
"api_key": "k2",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -391,14 +391,14 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch):
|
|||
[
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-free",
|
||||
"api_key": "k1",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-prem",
|
||||
"api_key": "k2",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -437,7 +437,7 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch):
|
|||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
|
||||
{"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"},
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -462,7 +462,7 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
|
|||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
|
||||
{"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"},
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -504,7 +504,7 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch):
|
|||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "venice/dead-model",
|
||||
"api_key": "k1",
|
||||
"billing_tier": "free",
|
||||
|
|
@ -514,7 +514,7 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch):
|
|||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemini-flash",
|
||||
"api_key": "k1",
|
||||
"billing_tier": "free",
|
||||
|
|
@ -556,7 +556,7 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch):
|
|||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5",
|
||||
"api_key": "k-yaml",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -566,7 +566,7 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch):
|
|||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "openai/gpt-5",
|
||||
"api_key": "k-or",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -608,7 +608,7 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch
|
|||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5",
|
||||
"api_key": "k-yaml",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -618,7 +618,7 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch
|
|||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemini-flash:free",
|
||||
"api_key": "k-or",
|
||||
"billing_tier": "free",
|
||||
|
|
@ -656,7 +656,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch):
|
|||
high_score_cfgs = [
|
||||
{
|
||||
"id": -i,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": f"gpt-x-{i}",
|
||||
"api_key": "k",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -668,7 +668,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch):
|
|||
]
|
||||
low_score_trap = {
|
||||
"id": -99,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "tiny-legacy",
|
||||
"api_key": "k",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -729,7 +729,7 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch):
|
|||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "venice/dead-model",
|
||||
"api_key": "k",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -739,7 +739,7 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch):
|
|||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5",
|
||||
"api_key": "k",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -781,7 +781,7 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch):
|
|||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5",
|
||||
"api_key": "k",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -791,7 +791,7 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch):
|
|||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5-pro",
|
||||
"api_key": "k",
|
||||
"billing_tier": "premium",
|
||||
|
|
@ -839,7 +839,7 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
|
|||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemma-4-26b-a4b-it:free",
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
|
|
@ -849,7 +849,7 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
|
|||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemini-2.5-flash:free",
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
|
|
@ -892,7 +892,7 @@ async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch):
|
|||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemma-4-26b-a4b-it:free",
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
|
|
@ -937,7 +937,7 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa
|
|||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemma-4-26b-a4b-it:free",
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
|
|
@ -947,7 +947,7 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa
|
|||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": "google/gemini-2.5-flash:free",
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ def _thread(*, pinned: int | None = None):
|
|||
def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
|
||||
return {
|
||||
"id": id_,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": f"vision-{id_}",
|
||||
"api_key": "k",
|
||||
"billing_tier": tier,
|
||||
|
|
@ -87,7 +87,7 @@ def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
|
|||
def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict:
|
||||
return {
|
||||
"id": id_,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": f"text-{id_}",
|
||||
"api_key": "k",
|
||||
"billing_tier": tier,
|
||||
|
|
@ -261,7 +261,7 @@ async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch):
|
|||
session = _FakeSession(_thread())
|
||||
cfg_unannotated_vision = {
|
||||
"id": -2,
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": "gpt-4o", # known vision model in LiteLLM map
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
|
|
|
|||
|
|
@ -25,10 +25,10 @@ def _fake_yaml_config(
|
|||
return {
|
||||
"id": id,
|
||||
"name": f"yaml-{id}",
|
||||
"provider": "OPENAI",
|
||||
"litellm_provider": "openai",
|
||||
"model_name": model_name,
|
||||
"api_key": "sk-test",
|
||||
"api_base": "",
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"billing_tier": billing_tier,
|
||||
"rpm": 100,
|
||||
"tpm": 100_000,
|
||||
|
|
@ -54,10 +54,10 @@ def _fake_openrouter_config(
|
|||
return {
|
||||
"id": id,
|
||||
"name": f"or-{id}",
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": model_name,
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"billing_tier": billing_tier,
|
||||
"rpm": 20 if billing_tier == "free" else 200,
|
||||
"tpm": 100_000 if billing_tier == "free" else 1_000_000,
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ def _or_cfg(
|
|||
) -> dict:
|
||||
return {
|
||||
"id": cid,
|
||||
"provider": "OPENROUTER",
|
||||
"litellm_provider": "openrouter",
|
||||
"model_name": model_name,
|
||||
"billing_tier": tier,
|
||||
"auto_pin_tier": "B" if tier == "premium" else "C",
|
||||
|
|
@ -144,7 +144,7 @@ async def test_enrich_health_only_touches_or_provider(monkeypatch):
|
|||
"""YAML cfgs that aren't OPENROUTER must be skipped entirely."""
|
||||
yaml_cfg = {
|
||||
"id": -1,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5",
|
||||
"billing_tier": "premium",
|
||||
"auto_pin_tier": "A",
|
||||
|
|
@ -313,7 +313,7 @@ async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch):
|
|||
"""When the catalogue has no OR cfgs at all, no HTTP calls fire."""
|
||||
yaml_cfg: dict[str, Any] = {
|
||||
"id": -1,
|
||||
"provider": "AZURE_OPENAI",
|
||||
"litellm_provider": "azure",
|
||||
"model_name": "gpt-5",
|
||||
"billing_tier": "premium",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ def test_safety_net_does_not_fire_for_azure_gpt_4o():
|
|||
it text-only."""
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="AZURE_OPENAI",
|
||||
litellm_provider="azure",
|
||||
model_name="my-azure-deployment",
|
||||
base_model="gpt-4o",
|
||||
)
|
||||
|
|
@ -49,7 +49,7 @@ def test_safety_net_does_not_fire_for_unknown_model():
|
|||
LiteLLM doesn't know about must flow through to the provider."""
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="CUSTOM",
|
||||
litellm_provider="custom",
|
||||
custom_provider="brand_new_proxy",
|
||||
model_name="brand-new-model-x9",
|
||||
)
|
||||
|
|
@ -69,7 +69,7 @@ def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch):
|
|||
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
litellm_provider="openai",
|
||||
model_name="gpt-4o",
|
||||
)
|
||||
is False
|
||||
|
|
@ -88,7 +88,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
|
|||
monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false)
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
litellm_provider="openai",
|
||||
model_name="text-only-stub",
|
||||
)
|
||||
is True
|
||||
|
|
@ -100,7 +100,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
|
|||
monkeypatch.setattr(pc.litellm, "get_model_info", _info_true)
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
litellm_provider="openai",
|
||||
model_name="vision-stub",
|
||||
)
|
||||
is False
|
||||
|
|
@ -112,7 +112,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch):
|
|||
monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing)
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
litellm_provider="openai",
|
||||
model_name="missing-key-stub",
|
||||
)
|
||||
is False
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue