refactor(model-connections): move backend model connections to provider capabilities

This commit is contained in:
Anish Sarkar 2026-06-12 02:17:22 +05:30
parent 3089dd4cb6
commit 5d5d574550
31 changed files with 772 additions and 476 deletions

View file

@ -17,13 +17,6 @@ branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
connection_protocol = postgresql.ENUM(
"OLLAMA",
"OPENAI_COMPATIBLE",
"ANTHROPIC",
name="connectionprotocol",
create_type=False,
)
connection_scope = postgresql.ENUM(
"GLOBAL",
"SEARCH_SPACE",
@ -73,36 +66,67 @@ def _add_searchspace_column_if_missing(column_name: str) -> None:
op.add_column("searchspaces", sa.Column(column_name, sa.Integer(), nullable=True))
def _drop_column_if_exists(table_name: str, column_name: str) -> None:
if _column_exists(table_name, column_name):
op.drop_column(table_name, column_name)
def _drop_index_if_exists(table_name: str, index_name: str) -> None:
if _index_exists(table_name, index_name):
op.drop_index(index_name, table_name=table_name)
def upgrade() -> None:
bind = op.get_bind()
connection_protocol.create(bind, checkfirst=True)
op.execute("ALTER TYPE connectionprotocol ADD VALUE IF NOT EXISTS 'ANTHROPIC'")
connection_scope.create(bind, checkfirst=True)
model_source.create(bind, checkfirst=True)
if _table_exists("connections"):
if _column_exists("connections", "native_provider") and not _column_exists(
"connections", "litellm_provider"
if _column_exists("connections", "litellm_provider") and not _column_exists(
"connections", "provider"
):
op.alter_column(
"connections",
"litellm_provider",
new_column_name="provider",
existing_type=sa.String(length=100),
existing_nullable=True,
)
op.alter_column(
"connections",
"provider",
existing_type=sa.String(length=100),
nullable=False,
)
elif _column_exists("connections", "native_provider") and not _column_exists(
"connections", "provider"
):
op.alter_column(
"connections",
"native_provider",
new_column_name="litellm_provider",
new_column_name="provider",
existing_type=sa.String(length=100),
existing_nullable=True,
)
elif not _column_exists("connections", "litellm_provider"):
op.alter_column(
"connections",
"provider",
existing_type=sa.String(length=100),
nullable=False,
)
elif not _column_exists("connections", "provider"):
op.add_column(
"connections",
sa.Column("litellm_provider", sa.String(length=100), nullable=True),
sa.Column("provider", sa.String(length=100), nullable=False),
)
_drop_index_if_exists("connections", "ix_connections_protocol")
_drop_column_if_exists("connections", "protocol")
else:
op.create_table(
"connections",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("protocol", connection_protocol, nullable=False),
sa.Column("litellm_provider", sa.String(length=100), nullable=True),
sa.Column("provider", sa.String(length=100), nullable=False),
sa.Column("base_url", sa.String(length=500), nullable=True),
sa.Column("api_key", sa.String(), nullable=True),
sa.Column(
@ -131,18 +155,20 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("id"),
)
if _index_exists("connections", "ix_connections_native_provider") and not _index_exists(
"connections", "ix_connections_litellm_provider"
"connections", "ix_connections_provider"
):
op.execute(
"ALTER INDEX ix_connections_native_provider "
"RENAME TO ix_connections_litellm_provider"
"RENAME TO ix_connections_provider"
)
_create_index_if_missing("ix_connections_protocol", "connections", ["protocol"])
_create_index_if_missing(
"ix_connections_litellm_provider",
"connections",
["litellm_provider"],
)
if _index_exists("connections", "ix_connections_litellm_provider") and not _index_exists(
"connections", "ix_connections_provider"
):
op.execute(
"ALTER INDEX ix_connections_litellm_provider "
"RENAME TO ix_connections_provider"
)
_create_index_if_missing("ix_connections_provider", "connections", ["provider"])
_create_index_if_missing("ix_connections_scope", "connections", ["scope"])
if not _table_exists("models"):
@ -159,24 +185,11 @@ def upgrade() -> None:
server_default="DISCOVERED",
nullable=False,
),
sa.Column(
"capabilities",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column(
"capabilities_declared",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column(
"capabilities_verified",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column("supports_chat", sa.Boolean(), nullable=True),
sa.Column("max_input_tokens", sa.Integer(), nullable=True),
sa.Column("supports_image_input", sa.Boolean(), nullable=True),
sa.Column("supports_tools", sa.Boolean(), nullable=True),
sa.Column("supports_image_generation", sa.Boolean(), nullable=True),
sa.Column(
"capabilities_override",
postgresql.JSONB(astext_type=sa.Text()),
@ -198,6 +211,24 @@ def upgrade() -> None:
"connection_id", "model_id", name="uq_models_connection_model_id"
),
)
else:
if not _column_exists("models", "supports_chat"):
op.add_column("models", sa.Column("supports_chat", sa.Boolean(), nullable=True))
if not _column_exists("models", "max_input_tokens"):
op.add_column("models", sa.Column("max_input_tokens", sa.Integer(), nullable=True))
if not _column_exists("models", "supports_image_input"):
op.add_column(
"models", sa.Column("supports_image_input", sa.Boolean(), nullable=True)
)
if not _column_exists("models", "supports_tools"):
op.add_column("models", sa.Column("supports_tools", sa.Boolean(), nullable=True))
if not _column_exists("models", "supports_image_generation"):
op.add_column(
"models", sa.Column("supports_image_generation", sa.Boolean(), nullable=True)
)
_drop_column_if_exists("models", "capabilities")
_drop_column_if_exists("models", "capabilities_declared")
_drop_column_if_exists("models", "capabilities_verified")
_create_index_if_missing("ix_models_connection_id", "models", ["connection_id"])
_create_index_if_missing("ix_models_model_id", "models", ["model_id"])
_create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"])
@ -206,6 +237,8 @@ def upgrade() -> None:
_add_searchspace_column_if_missing("image_gen_model_id")
_add_searchspace_column_if_missing("vision_model_id")
op.execute("DROP TYPE IF EXISTS connectionprotocol")
def downgrade() -> None:
op.drop_column("searchspaces", "vision_model_id")
@ -218,11 +251,9 @@ def downgrade() -> None:
op.drop_table("models")
op.drop_index(op.f("ix_connections_scope"), table_name="connections")
op.drop_index(op.f("ix_connections_litellm_provider"), table_name="connections")
op.drop_index(op.f("ix_connections_protocol"), table_name="connections")
op.drop_index(op.f("ix_connections_provider"), table_name="connections")
op.drop_table("connections")
bind = op.get_bind()
model_source.drop(bind, checkfirst=True)
connection_scope.drop(bind, checkfirst=True)
connection_protocol.drop(bind, checkfirst=True)

View file

@ -29,6 +29,7 @@ from app.services.auto_model_pin_service import (
auto_model_candidates,
choose_auto_model_candidate,
)
from app.services.model_capabilities import has_capability
from app.services.model_resolver import to_litellm
from app.utils.signed_image_urls import generate_image_token
@ -146,9 +147,7 @@ def create_generate_image_tool(
if config_id < 0:
global_model = _get_global_model(config_id)
if not global_model or not (
global_model.get("capabilities") or {}
).get("image_gen"):
if not global_model or not has_capability(global_model, "image_gen"):
err = f"Image generation model {config_id} not found"
return _failed({"error": err}, error=err)
global_connection = _get_global_connection(
@ -191,7 +190,7 @@ def create_generate_image_tool(
):
err = f"Image generation model {config_id} not found"
return _failed({"error": err}, error=err)
if not (db_model.capabilities or {}).get("image_gen"):
if not has_capability(db_model, "image_gen"):
err = f"Model {config_id} is not image-generation capable"
return _failed({"error": err}, error=err)

View file

@ -49,16 +49,19 @@ def _sanitize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
reject the blank text. The OpenAI spec says ``content`` should be
``null`` when an assistant message only carries tool calls.
"""
sanitized: list[BaseMessage] = []
for msg in messages:
if isinstance(msg.content, list):
msg.content = _sanitize_content(msg.content)
next_msg = msg.model_copy(deep=True)
if isinstance(next_msg.content, list):
next_msg.content = _sanitize_content(next_msg.content)
if (
isinstance(msg, AIMessage)
and (not msg.content or msg.content == "")
and getattr(msg, "tool_calls", None)
isinstance(next_msg, AIMessage)
and (not next_msg.content or next_msg.content == "")
and getattr(next_msg, "tool_calls", None)
):
msg.content = None # type: ignore[assignment]
return messages
next_msg.content = None # type: ignore[assignment]
sanitized.append(next_msg)
return sanitized
class SanitizedChatLiteLLM(ChatLiteLLM):
@ -89,6 +92,22 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
):
yield chunk
async def _agenerate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: AsyncCallbackManagerForLLMRun | None = None,
stream: bool | None = None,
**kwargs: Any,
) -> ChatResult:
return await super()._agenerate(
_sanitize_messages(messages),
stop=stop,
run_manager=run_manager,
stream=stream,
**kwargs,
)
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
"""Attach a ``profile`` dict to ChatLiteLLM with model context metadata."""
@ -210,7 +229,7 @@ class AgentConfig:
# BYOK rows have no curated flag; ask LiteLLM (default-allow on
# unknown). The streaming safety net still blocks explicit text-only.
supports_image_input=derive_supports_image_input(
litellm_provider=provider_value.lower(),
provider=provider_value.lower(),
model_name=config.model_name,
base_model=base_model,
custom_provider=config.custom_provider,
@ -229,7 +248,7 @@ class AgentConfig:
system_instructions = yaml_config.get("system_instructions", "")
provider = yaml_config.get("litellm_provider", "")
provider = yaml_config.get("provider") or yaml_config.get("litellm_provider", "")
model_name = yaml_config.get("model_name", "")
custom_provider = yaml_config.get("custom_provider")
litellm_params = yaml_config.get("litellm_params") or {}
@ -245,7 +264,7 @@ class AgentConfig:
supports_image_input = bool(yaml_config.get("supports_image_input"))
else:
supports_image_input = derive_supports_image_input(
litellm_provider=provider,
provider=provider,
model_name=model_name,
base_model=base_model,
custom_provider=custom_provider,
@ -396,8 +415,8 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
if llm_config.get("custom_provider"):
model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}"
else:
litellm_provider = llm_config.get("litellm_provider", "openai")
model_string = f"{litellm_provider}/{llm_config['model_name']}"
provider = llm_config.get("provider") or llm_config.get("litellm_provider", "openai")
model_string = f"{provider}/{llm_config['model_name']}"
litellm_kwargs = {
"model": model_string,

View file

@ -33,7 +33,6 @@ from app.config import (
initialize_llm_router,
initialize_openrouter_integration,
initialize_pricing_registration,
initialize_vision_llm_router,
)
from app.db import User, create_db_and_tables, get_async_session
from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError
@ -622,7 +621,6 @@ async def lifespan(app: FastAPI):
initialize_pricing_registration()
initialize_llm_router()
initialize_image_gen_router()
initialize_vision_llm_router()
# Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays
# worker readiness. ``shield`` so Uvicorn cancelling startup

View file

@ -115,14 +115,12 @@ def init_worker(**kwargs):
initialize_llm_router,
initialize_openrouter_integration,
initialize_pricing_registration,
initialize_vision_llm_router,
)
initialize_openrouter_integration()
initialize_pricing_registration()
initialize_llm_router()
initialize_image_gen_router()
initialize_vision_llm_router()
# Celery configuration, sourced from the central Config singleton

View file

@ -103,7 +103,7 @@ def load_global_llm_configs():
else None
)
cfg["supports_image_input"] = derive_supports_image_input(
litellm_provider=cfg.get("litellm_provider"),
provider=cfg.get("provider") or cfg.get("litellm_provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
@ -122,7 +122,7 @@ def load_global_llm_configs():
# Stamp Auto (Fastest) ranking metadata. YAML configs are always
# Tier A — operator-curated, locked first when premium-eligible.
# The OpenRouter refresh tick later re-stamps health for any cfg
# whose litellm_provider == "openrouter" via _enrich_health.
# whose provider == "openrouter" via _enrich_health.
try:
from app.services.quality_score import static_score_yaml
@ -132,7 +132,7 @@ def load_global_llm_configs():
cfg["quality_score_static"] = static_q
cfg["quality_score"] = static_q
cfg["quality_score_health"] = None
# YAML cfgs whose litellm_provider is openrouter are also subject
# YAML cfgs whose provider is openrouter are also subject
# to health gating against their own /endpoints data — a
# hand-picked dead OR model is still dead. _enrich_health
# re-stamps health_gated for them on the next refresh tick.
@ -362,8 +362,8 @@ def initialize_openrouter_integration():
else:
print("Info: OpenRouter integration enabled but no models fetched")
# Image generation + vision LLM emissions are opt-in (issue L).
# Both reuse the catalogue already cached by ``service.initialize``
# Image generation emissions reuse the catalogue already cached by
# ``service.initialize``
# so we don't make additional network calls here.
if settings.get("image_generation_enabled"):
try:
@ -377,18 +377,6 @@ def initialize_openrouter_integration():
except Exception as e:
print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}")
if settings.get("vision_enabled"):
try:
vision_configs = service.get_vision_llm_configs()
if vision_configs:
config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs)
print(
f"Info: OpenRouter integration added {len(vision_configs)} "
f"vision LLM models"
)
except Exception as e:
print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}")
refresh_global_model_catalog()
except Exception as e:
print(f"Warning: Failed to initialize OpenRouter integration: {e}")
@ -399,7 +387,6 @@ def materialize_global_configs():
return materialize_global_model_catalog(
chat_configs=getattr(config, "GLOBAL_LLM_CONFIGS", []),
vision_configs=getattr(config, "GLOBAL_VISION_LLM_CONFIGS", []),
image_configs=getattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", []),
)
@ -493,29 +480,9 @@ def initialize_image_gen_router():
def initialize_vision_llm_router():
vision_configs = load_global_vision_llm_configs()
# Reuse the router settings already parsed at Config construction. The
# *configs* list is intentionally re-read from YAML (it must exclude the
# OpenRouter-injected dynamic models held in config.GLOBAL_VISION_LLM_CONFIGS).
router_settings = config.VISION_LLM_ROUTER_SETTINGS
if not vision_configs:
print(
"Info: No global vision LLM configs found, "
"Vision LLM Auto mode will not be available"
)
return
try:
from app.services.vision_llm_router_service import VisionLLMRouterService
VisionLLMRouterService.initialize(vision_configs, router_settings)
print(
f"Info: Vision LLM Router initialized with {len(vision_configs)} models "
f"(strategy: {router_settings.get('routing_strategy', 'usage-based-routing')})"
)
except Exception as e:
print(f"Warning: Failed to initialize Vision LLM Router: {e}")
# Retired: vision Auto now uses shared capability-filtered model selection
# over GLOBAL/BYOK chat models with supports_image_input=true.
return
class Config:
@ -874,7 +841,6 @@ class Config:
GLOBAL_CONNECTIONS, GLOBAL_MODELS = _materialize_global_model_catalog(
chat_configs=GLOBAL_LLM_CONFIGS,
vision_configs=GLOBAL_VISION_LLM_CONFIGS,
image_configs=GLOBAL_IMAGE_GEN_CONFIGS,
)
del _materialize_global_model_catalog

View file

@ -280,12 +280,6 @@ class VisionProvider(StrEnum):
CUSTOM = "CUSTOM"
class ConnectionProtocol(StrEnum):
OLLAMA = "OLLAMA"
OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE"
ANTHROPIC = "ANTHROPIC"
class ConnectionScope(StrEnum):
GLOBAL = "GLOBAL"
SEARCH_SPACE = "SEARCH_SPACE"
@ -1662,8 +1656,7 @@ class Report(BaseModel, TimestampMixin):
class Connection(BaseModel, TimestampMixin):
__tablename__ = "connections"
protocol = Column(SQLAlchemyEnum(ConnectionProtocol), nullable=False, index=True)
litellm_provider = Column(String(100), nullable=True, index=True)
provider = Column(String(100), nullable=False, index=True)
base_url = Column(String(500), nullable=True)
api_key = Column(String, nullable=True)
extra = Column(JSONB, nullable=False, default=dict, server_default="{}")
@ -1715,9 +1708,11 @@ class Model(BaseModel, TimestampMixin):
default=ModelSource.DISCOVERED,
server_default=ModelSource.DISCOVERED.value,
)
capabilities = Column(JSONB, nullable=False, default=dict, server_default="{}")
capabilities_declared = Column(JSONB, nullable=False, default=dict, server_default="{}")
capabilities_verified = Column(JSONB, nullable=False, default=dict, server_default="{}")
supports_chat = Column(Boolean, nullable=True)
max_input_tokens = Column(Integer, nullable=True)
supports_image_input = Column(Boolean, nullable=True)
supports_tools = Column(Boolean, nullable=True)
supports_image_generation = Column(Boolean, nullable=True)
capabilities_override = Column(JSONB, nullable=False, default=dict, server_default="{}")
embedding_dimension = Column(Integer, nullable=True)
enabled = Column(Boolean, nullable=False, default=True, server_default="true")

View file

@ -132,7 +132,7 @@ async def list_anonymous_models():
id=cfg.get("id", 0),
name=cfg.get("name", ""),
description=cfg.get("description"),
provider=cfg.get("litellm_provider", ""),
provider=cfg.get("provider") or cfg.get("litellm_provider", ""),
model_name=cfg.get("model_name", ""),
billing_tier=cfg.get("billing_tier", "free"),
is_premium=cfg.get("billing_tier", "free") == "premium",
@ -161,7 +161,7 @@ async def get_anonymous_model(slug: str):
id=cfg.get("id", 0),
name=cfg.get("name", ""),
description=cfg.get("description"),
provider=cfg.get("litellm_provider", ""),
provider=cfg.get("provider") or cfg.get("litellm_provider", ""),
model_name=cfg.get("model_name", ""),
billing_tier=cfg.get("billing_tier", "free"),
is_premium=cfg.get("billing_tier", "free") == "premium",

View file

@ -52,6 +52,7 @@ from app.services.auto_model_pin_service import (
choose_auto_model_candidate,
)
from app.services.model_resolver import to_litellm
from app.services.model_capabilities import has_capability
from app.users import current_active_user
from app.utils.rbac import check_permission
from app.utils.signed_image_urls import verify_image_token
@ -166,7 +167,7 @@ async def _execute_image_generation(
if config_id < 0:
global_model = _get_global_model(config_id)
if not global_model or not (global_model.get("capabilities") or {}).get("image_gen"):
if not global_model or not has_capability(global_model, "image_gen"):
raise ValueError(f"Global image generation model {config_id} not found")
global_connection = _get_global_connection(global_model["connection_id"])
if not global_connection:
@ -200,7 +201,7 @@ async def _execute_image_generation(
raise ValueError(f"Image generation model {config_id} not found")
if conn.user_id is not None and conn.user_id != search_space.user_id:
raise ValueError(f"Image generation model {config_id} not found")
if not (db_model.capabilities or {}).get("image_gen"):
if not has_capability(db_model, "image_gen"):
raise ValueError(f"Model {config_id} is not image-generation capable")
model_string, resolved_kwargs = to_litellm(
@ -272,7 +273,7 @@ async def get_global_image_gen_configs(
"id": cfg.get("id"),
"name": cfg.get("name"),
"description": cfg.get("description"),
"provider": cfg.get("litellm_provider"),
"provider": cfg.get("provider") or cfg.get("litellm_provider"),
"custom_provider": cfg.get("custom_provider"),
"model_name": cfg.get("model_name"),
"api_base": cfg.get("api_base") or None,

View file

@ -8,7 +8,6 @@ from sqlalchemy.orm import selectinload
from app.config import config
from app.db import (
Connection,
ConnectionProtocol,
ConnectionScope,
Model,
ModelSource,
@ -22,6 +21,7 @@ from app.schemas import (
ConnectionRead,
ConnectionUpdate,
ModelCreate,
ModelProviderRead,
ModelRead,
ModelRolesRead,
ModelRolesUpdate,
@ -34,6 +34,8 @@ from app.services.model_connection_service import (
persist_verification,
test_model,
)
from app.services.model_capabilities import has_capability
from app.services.provider_registry import REGISTRY
from app.users import current_active_user
from app.utils.rbac import check_permission
@ -41,16 +43,6 @@ router = APIRouter()
logger = logging.getLogger(__name__)
def _default_litellm_provider(protocol: ConnectionProtocol | str) -> str:
protocol_value = getattr(protocol, "value", str(protocol))
defaults = {
ConnectionProtocol.OLLAMA.value: "ollama_chat",
ConnectionProtocol.ANTHROPIC.value: "anthropic",
ConnectionProtocol.OPENAI_COMPATIBLE.value: "openai",
}
return defaults.get(protocol_value, "openai")
def _model_read(model: Model | dict) -> ModelRead:
return ModelRead.model_validate(model)
@ -68,8 +60,7 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None
return ConnectionRead(
id=conn.id,
protocol=conn.protocol,
litellm_provider=conn.litellm_provider,
provider=conn.provider,
base_url=conn.base_url,
extra=conn.extra or {},
scope=conn.scope,
@ -85,6 +76,60 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None
)
def _apply_model_facts(model: Model, facts: dict) -> None:
model.supports_chat = facts.get("supports_chat")
model.max_input_tokens = facts.get("max_input_tokens")
model.supports_image_input = facts.get("supports_image_input")
model.supports_tools = facts.get("supports_tools")
model.supports_image_generation = facts.get("supports_image_generation")
def _default_model_for(models: list[Model], capability: str) -> int | None:
for model in models:
if model.enabled and has_capability(model, capability):
return model.id
return None
async def _default_unset_roles(
session: AsyncSession,
conn: Connection,
models: list[Model],
) -> None:
if conn.scope != ConnectionScope.SEARCH_SPACE or conn.search_space_id is None:
return
search_space = await _get_search_space(session, conn.search_space_id)
if search_space.chat_model_id is None:
search_space.chat_model_id = _default_model_for(models, "chat")
if search_space.vision_model_id is None:
vision_default = None
if search_space.chat_model_id:
chat_model = next((m for m in models if m.id == search_space.chat_model_id), None)
if chat_model and has_capability(chat_model, "vision"):
vision_default = chat_model.id
search_space.vision_model_id = vision_default or _default_model_for(models, "vision")
if search_space.image_gen_model_id is None:
search_space.image_gen_model_id = _default_model_for(models, "image_gen")
@router.get("/model-providers", response_model=list[ModelProviderRead])
async def list_model_providers(user: User = Depends(current_active_user)):
del user
local_only = {"ollama_chat", "lm_studio"}
return [
ModelProviderRead(
provider=provider,
transport=spec.transport.value,
discovery=spec.discovery,
default_base_url=spec.default_base_url,
base_url_required=spec.base_url_required,
auth_style=spec.auth_style,
local_only=provider in local_only,
)
for provider, spec in sorted(REGISTRY.items())
]
async def _get_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace:
result = await session.execute(select(SearchSpace).where(SearchSpace.id == search_space_id))
search_space = result.scalars().first()
@ -180,8 +225,6 @@ async def create_connection(
"You don't have permission to create model connections in this search space",
)
payload = data.model_dump(exclude={"search_space_id"})
if not payload.get("litellm_provider"):
payload["litellm_provider"] = _default_litellm_provider(data.protocol)
conn = Connection(
**payload,
@ -254,24 +297,21 @@ async def discover_connection_models(
model_id=item["model_id"],
display_name=item.get("display_name"),
source=item["source"],
capabilities=item["capabilities"],
capabilities_declared=item["capabilities"],
capabilities_verified={},
capabilities_override={},
enabled=False,
catalog=item.get("metadata") or {},
)
_apply_model_facts(db_model, item)
session.add(db_model)
else:
db_model.display_name = item.get("display_name") or db_model.display_name
db_model.capabilities_declared = item["capabilities"]
db_model.capabilities = {
**item["capabilities"],
**(db_model.capabilities_override or {}),
}
_apply_model_facts(db_model, item)
db_model.catalog = item.get("metadata") or db_model.catalog
await session.commit()
conn = await _load_connection(session, connection_id)
await _default_unset_roles(session, conn, list(conn.models))
await session.commit()
conn = await _load_connection(session, connection_id)
return [_model_read(model) for model in conn.models]
@ -297,16 +337,17 @@ async def add_manual_model(
model_id=model_id,
display_name=data.display_name or None,
source=ModelSource.MANUAL,
capabilities=capabilities,
capabilities_declared=capabilities,
capabilities_verified={},
capabilities_override={},
enabled=True,
catalog={},
)
_apply_model_facts(model, capabilities)
session.add(model)
await session.commit()
await session.refresh(model)
conn = await _load_connection(session, connection_id)
await _default_unset_roles(session, conn, list(conn.models))
await session.commit()
return _model_read(model)
@ -327,11 +368,6 @@ async def update_model(
update = data.model_dump(exclude_unset=True)
for key, value in update.items():
setattr(model, key, value)
if "capabilities_override" in update:
model.capabilities = {
**(model.capabilities_declared or {}),
**(model.capabilities_override or {}),
}
await session.commit()
await session.refresh(model)
return _model_read(model)

View file

@ -1741,12 +1741,11 @@ async def handle_new_chat(
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found")
# Use agent_llm_id from search space for chat operations
# Positive IDs load from NewLLMConfig database table
# Negative IDs load from YAML global configs
# Falls back to -1 (first global config) if not configured
# Use the converged model-connections role for chat operations.
# Positive IDs load Model + Connection rows; negative IDs load
# virtual GLOBAL models; 0 means Auto.
llm_config_id = (
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
search_space.chat_model_id if search_space.chat_model_id is not None else 0
)
# Release the read-transaction so we don't hold ACCESS SHARE locks
@ -2228,7 +2227,7 @@ async def regenerate_response(
raise HTTPException(status_code=404, detail="Search space not found")
llm_config_id = (
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
search_space.chat_model_id if search_space.chat_model_id is not None else 0
)
# Release the read-transaction so we don't hold ACCESS SHARE locks
@ -2393,7 +2392,7 @@ async def resume_chat(
raise HTTPException(status_code=404, detail="Search space not found")
llm_config_id = (
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
search_space.chat_model_id if search_space.chat_model_id is not None else 0
)
decisions = [d.model_dump() for d in request.decisions]

View file

@ -57,7 +57,7 @@ def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead:
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
)
supports_image_input = derive_supports_image_input(
litellm_provider=provider_value.lower(),
provider=provider_value.lower(),
model_name=config.model_name,
base_model=base_model,
custom_provider=config.custom_provider,
@ -147,7 +147,7 @@ async def get_global_new_llm_configs(
else None
)
supports_image_input = derive_supports_image_input(
litellm_provider=cfg.get("litellm_provider"),
provider=cfg.get("provider") or cfg.get("litellm_provider"),
model_name=cfg.get("model_name"),
base_model=cfg_base_model,
custom_provider=cfg.get("custom_provider"),
@ -157,7 +157,7 @@ async def get_global_new_llm_configs(
"id": cfg.get("id"),
"name": cfg.get("name"),
"description": cfg.get("description"),
"provider": cfg.get("litellm_provider"),
"provider": cfg.get("provider") or cfg.get("litellm_provider"),
"custom_provider": cfg.get("custom_provider"),
"model_name": cfg.get("model_name"),
"api_base": cfg.get("api_base") or None,

View file

@ -419,7 +419,7 @@ async def _get_llm_config_by_id(
"id": cfg.get("id"),
"name": cfg.get("name"),
"description": cfg.get("description"),
"provider": cfg.get("litellm_provider"),
"provider": cfg.get("provider") or cfg.get("litellm_provider"),
"custom_provider": cfg.get("custom_provider"),
"model_name": cfg.get("model_name"),
"api_base": cfg.get("api_base"),
@ -490,7 +490,7 @@ async def _get_image_gen_config_by_id(
"id": cfg.get("id"),
"name": cfg.get("name"),
"description": cfg.get("description"),
"provider": cfg.get("litellm_provider"),
"provider": cfg.get("provider") or cfg.get("litellm_provider"),
"custom_provider": cfg.get("custom_provider"),
"model_name": cfg.get("model_name"),
"api_base": cfg.get("api_base") or None,
@ -550,7 +550,7 @@ async def _get_vision_llm_config_by_id(
"id": cfg.get("id"),
"name": cfg.get("name"),
"description": cfg.get("description"),
"provider": cfg.get("litellm_provider"),
"provider": cfg.get("provider") or cfg.get("litellm_provider"),
"custom_provider": cfg.get("custom_provider"),
"model_name": cfg.get("model_name"),
"api_base": cfg.get("api_base") or None,

View file

@ -96,7 +96,7 @@ async def get_global_vision_llm_configs(
"id": cfg.get("id"),
"name": cfg.get("name"),
"description": cfg.get("description"),
"provider": cfg.get("litellm_provider"),
"provider": cfg.get("provider") or cfg.get("litellm_provider"),
"custom_provider": cfg.get("custom_provider"),
"model_name": cfg.get("model_name"),
"api_base": cfg.get("api_base") or None,

View file

@ -49,6 +49,7 @@ from .model_connections import (
ConnectionRead,
ConnectionUpdate,
ModelCreate,
ModelProviderRead,
ModelRead,
ModelRolesRead,
ModelRolesUpdate,

View file

@ -4,7 +4,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from app.db import ConnectionProtocol, ConnectionScope, ModelSource
from app.db import ConnectionScope, ModelSource
class ModelRead(BaseModel):
@ -13,9 +13,11 @@ class ModelRead(BaseModel):
model_id: str
display_name: str | None = None
source: ModelSource | str
capabilities: dict[str, Any]
capabilities_declared: dict[str, Any] = Field(default_factory=dict)
capabilities_verified: dict[str, Any] = Field(default_factory=dict)
supports_chat: bool | None = None
max_input_tokens: int | None = None
supports_image_input: bool | None = None
supports_tools: bool | None = None
supports_image_generation: bool | None = None
capabilities_override: dict[str, Any] = Field(default_factory=dict)
embedding_dimension: int | None = None
enabled: bool
@ -28,8 +30,7 @@ class ModelRead(BaseModel):
class ConnectionRead(BaseModel):
id: int
protocol: ConnectionProtocol | str
litellm_provider: str | None = None
provider: str
base_url: str | None = None
extra: dict[str, Any] = Field(default_factory=dict)
scope: ConnectionScope | str
@ -47,8 +48,7 @@ class ConnectionRead(BaseModel):
class ConnectionCreate(BaseModel):
protocol: ConnectionProtocol
litellm_provider: str | None = Field(None, max_length=100)
provider: str = Field(..., max_length=100)
base_url: str | None = Field(None, max_length=500)
api_key: str | None = None
extra: dict[str, Any] = Field(default_factory=dict)
@ -58,7 +58,7 @@ class ConnectionCreate(BaseModel):
class ConnectionUpdate(BaseModel):
litellm_provider: str | None = Field(None, max_length=100)
provider: str | None = Field(None, max_length=100)
base_url: str | None = Field(None, max_length=500)
api_key: str | None = None
extra: dict[str, Any] | None = None
@ -79,9 +79,24 @@ class ModelCreate(BaseModel):
class ModelUpdate(BaseModel):
display_name: str | None = Field(None, max_length=255)
enabled: bool | None = None
supports_chat: bool | None = None
max_input_tokens: int | None = None
supports_image_input: bool | None = None
supports_tools: bool | None = None
supports_image_generation: bool | None = None
capabilities_override: dict[str, Any] | None = None
class ModelProviderRead(BaseModel):
provider: str
transport: str
discovery: str
default_base_url: str | None = None
base_url_required: bool
auth_style: str
local_only: bool = False
class VerifyConnectionResponse(BaseModel):
status: str
ok: bool

View file

@ -27,6 +27,7 @@ from sqlalchemy.orm import selectinload
from app.config import config
from app.db import Connection, Model, NewChatThread
from app.services.model_capabilities import has_capability
from app.services.quality_score import _QUALITY_TOP_K
from app.services.token_quota_service import TokenQuotaService
@ -62,18 +63,13 @@ def _is_usable_global_config(cfg: dict) -> bool:
return bool(
cfg.get("id") is not None
and cfg.get("model_name")
and cfg.get("litellm_provider")
and (cfg.get("provider") or cfg.get("litellm_provider"))
and cfg.get("api_key")
)
def _has_capability(model: dict | Model, capability: str) -> bool:
caps = (
model.get("capabilities", {})
if isinstance(model, dict)
else model.capabilities or {}
)
return bool(caps.get(capability))
return has_capability(model, capability)
def _prune_runtime_cooldowns(now_ts: float | None = None) -> None:
@ -196,7 +192,7 @@ def _cfg_supports_image_input(cfg: dict) -> bool:
else None
)
return derive_supports_image_input(
litellm_provider=cfg.get("litellm_provider"),
provider=cfg.get("provider") or cfg.get("litellm_provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
@ -253,9 +249,13 @@ def _global_candidates(
"model_id": model.get("model_id"),
"source": "global",
"connection": connection,
"capabilities": model.get("capabilities") or {},
"supports_chat": model.get("supports_chat"),
"supports_image_input": model.get("supports_image_input"),
"supports_tools": model.get("supports_tools"),
"supports_image_generation": model.get("supports_image_generation"),
"capabilities_override": model.get("capabilities_override") or {},
"billing_tier": model.get("billing_tier", "free"),
"litellm_provider": connection.get("litellm_provider"),
"provider": connection.get("provider"),
"model_name": model.get("model_id"),
"auto_pin_tier": catalog.get("auto_pin_tier")
or cfg.get("auto_pin_tier")
@ -310,9 +310,13 @@ async def _db_candidates(
"model_id": model.model_id,
"source": "db",
"connection": conn,
"capabilities": model.capabilities or {},
"supports_chat": model.supports_chat,
"supports_image_input": model.supports_image_input,
"supports_tools": model.supports_tools,
"supports_image_generation": model.supports_image_generation,
"capabilities_override": model.capabilities_override or {},
"billing_tier": "byok",
"litellm_provider": conn.litellm_provider,
"provider": conn.provider,
"model_name": model.model_id,
"auto_pin_tier": catalog.get("auto_pin_tier") or "BYOK",
"quality_score": catalog.get("quality_score") or 75,
@ -357,7 +361,7 @@ def _is_preferred_premium_auto_config(cfg: dict) -> bool:
return (
cfg.get("source") == "global"
and _tier_of(cfg) == "premium"
and str(cfg.get("litellm_provider", "")).lower() == "azure"
and str(cfg.get("provider", "")).lower() == "azure"
and str(cfg.get("model_name", "")).lower() == "gpt-5.4"
)

View file

@ -18,8 +18,7 @@ def _connection_key(conn: dict[str, Any]) -> tuple[Any, ...]:
# Deliberately includes api_key because two operator-owned credentials for
# the same provider/base can have different quota/rate limits upstream.
return (
conn.get("protocol"),
conn.get("litellm_provider"),
conn.get("provider"),
conn.get("base_url"),
conn.get("api_key"),
_freeze(conn.get("extra") or {}),
@ -34,16 +33,6 @@ def _freeze(value: Any) -> Any:
return value
def _capabilities_for(role: str, config: dict[str, Any]) -> dict[str, bool]:
return {
"chat": role == "chat",
"vision": role == "vision" or bool(config.get("supports_image_input")),
"image_gen": role == "image_gen",
"embedding": False,
"tools": bool(config.get("supports_tools", False)),
}
def _catalog_metadata(config: dict[str, Any]) -> dict[str, Any]:
return {
"billing_tier": config.get("billing_tier", "free"),
@ -72,7 +61,6 @@ def _catalog_metadata(config: dict[str, Any]) -> dict[str, Any]:
def materialize_global_model_catalog(
*,
chat_configs: list[dict[str, Any]],
vision_configs: list[dict[str, Any]],
image_configs: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
connections: list[dict[str, Any]] = []
@ -109,9 +97,13 @@ def materialize_global_model_catalog(
"model_id": config["model_name"],
"display_name": config.get("name") or config["model_name"],
"source": "MANUAL",
"capabilities": _capabilities_for(role, config),
"capabilities_declared": _capabilities_for(role, config),
"capabilities_verified": _capabilities_for(role, config),
"supports_chat": role == "chat",
"max_input_tokens": config.get("max_input_tokens"),
"supports_image_input": (
role == "chat" and bool(config.get("supports_image_input"))
),
"supports_tools": bool(config.get("supports_tools", False)),
"supports_image_generation": role == "image_gen",
"capabilities_override": {},
"embedding_dimension": None,
"enabled": True,
@ -125,10 +117,6 @@ def materialize_global_model_catalog(
if cfg.get("is_auto_mode"):
continue
add_config(cfg, "chat")
for cfg in vision_configs:
if cfg.get("is_auto_mode"):
continue
add_config(cfg, "vision")
for cfg in image_configs:
if cfg.get("is_auto_mode"):
continue

View file

@ -15,6 +15,7 @@ from app.services.auto_model_pin_service import (
choose_auto_model_candidate,
)
from app.services.llm_router_service import AUTO_MODE_ID, ChatLiteLLMRouter, is_auto_mode
from app.services.model_capabilities import has_capability
from app.services.model_resolver import native_connection_from_config, to_litellm
from app.services.token_tracking_service import token_tracker
@ -76,7 +77,7 @@ def _legacy_config_connection(
api_version: str | None = None,
) -> tuple[str, dict]:
cfg = {
"litellm_provider": provider.lower(),
"provider": provider.lower(),
"model_name": model_name,
"api_key": api_key,
"api_base": api_base,
@ -136,12 +137,7 @@ def get_global_connection(connection_id: int) -> dict | None:
def _has_capability(model: dict | Model, capability: str) -> bool:
caps = (
model.get("capabilities", {})
if isinstance(model, dict)
else model.capabilities or {}
)
return bool(caps.get(capability))
return has_capability(model, capability)
def _chat_litellm_from_resolved(
@ -420,8 +416,6 @@ async def get_vision_llm(
unwrapped they don't consume premium credit (issue M).
"""
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
from app.services.vision_llm_router_service import is_vision_auto_mode
try:
result = await session.execute(
select(SearchSpace).where(SearchSpace.id == search_space_id)
@ -468,7 +462,7 @@ async def get_vision_llm(
logger.error(f"No vision LLM configured for search space {search_space_id}")
return None
if is_vision_auto_mode(config_id):
if config_id == AUTO_MODE_ID:
candidates = await auto_model_candidates(
session,
search_space_id=search_space_id,

View file

@ -0,0 +1,36 @@
"""Override-aware model capability lookup."""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
CAPABILITY_FIELDS = {
"chat": "supports_chat",
"vision": "supports_image_input",
"image_gen": "supports_image_generation",
"tools": "supports_tools",
}
def _get_value(model: Any, key: str) -> Any:
if isinstance(model, Mapping):
return model.get(key)
return getattr(model, key, None)
def has_capability(model: Any, capability: str) -> bool:
field = CAPABILITY_FIELDS.get(capability)
if field is None:
return False
override = _get_value(model, "capabilities_override") or {}
if isinstance(override, Mapping) and field in override:
return bool(override[field])
if isinstance(override, Mapping) and capability in override:
return bool(override[capability])
return bool(_get_value(model, field))
__all__ = ["CAPABILITY_FIELDS", "has_capability"]

View file

@ -11,8 +11,10 @@ from typing import Any
import httpx
import litellm
from app.db import Connection, ConnectionProtocol, Model, ModelSource
from app.db import Connection, Model, ModelSource
from app.services.model_resolver import ensure_v1, to_litellm
from app.services.openrouter_model_normalizer import normalize_openrouter_models
from app.services.provider_registry import Transport, spec_for
logger = logging.getLogger(__name__)
@ -41,6 +43,16 @@ def _anthropic_headers(conn: Connection) -> dict[str, str]:
return headers
def _base_url_or_default(conn: Connection) -> str | None:
if conn.base_url:
return conn.base_url.rstrip("/")
if conn.provider == "openai":
return "https://api.openai.com/v1"
if conn.provider == "anthropic":
return "https://api.anthropic.com/v1"
return spec_for(conn.provider).default_base_url
def _docker_hint(url: str | None, exc_or_status: Any) -> str:
raw = str(exc_or_status)
if not url:
@ -61,32 +73,30 @@ def _docker_hint(url: str | None, exc_or_status: Any) -> str:
async def verify_connection(conn: Connection) -> VerifyResult:
if not conn.base_url:
spec = spec_for(conn.provider)
base_url = _base_url_or_default(conn)
if spec.base_url_required and not base_url:
return VerifyResult("UNREACHABLE", False, "Base URL is required.")
if conn.protocol == ConnectionProtocol.OLLAMA:
url = f"{conn.base_url.rstrip('/')}/api/version"
elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE:
url = f"{ensure_v1(conn.base_url)}/models"
elif conn.protocol == ConnectionProtocol.ANTHROPIC:
url = f"{conn.base_url.rstrip('/')}/models"
if spec.transport == Transport.OLLAMA and base_url:
url = f"{base_url.rstrip('/')}/api/version"
elif spec.discovery in {"openai_models", "openrouter"} and base_url:
url = f"{ensure_v1(base_url)}/models"
elif spec.discovery == "anthropic_models" and base_url:
url = f"{base_url.rstrip('/')}/models"
else:
return VerifyResult("UNREACHABLE", False, "Unsupported connection protocol.")
return VerifyResult("OK", True, "Connection uses provider-native authentication.")
try:
async with httpx.AsyncClient(timeout=VERIFY_TIMEOUT_SECONDS) as client:
headers = (
_anthropic_headers(conn)
if conn.protocol == ConnectionProtocol.ANTHROPIC
else _auth_headers(conn)
)
headers = _anthropic_headers(conn) if spec.auth_style == "x-api-key" else _auth_headers(conn)
response = await client.get(url, headers=headers)
if response.status_code in (401, 403):
return VerifyResult("AUTH_FAILED", False, "Authentication failed.")
if response.status_code == 404:
if conn.protocol == ConnectionProtocol.OLLAMA and url.endswith("/v1/models"):
if spec.transport == Transport.OLLAMA and url.endswith("/v1/models"):
message = "Ollama native API should not use /v1."
elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE:
elif spec.transport == Transport.OPENAI_COMPATIBLE:
message = "OpenAI-compatible servers should expose /v1/models."
else:
message = "Endpoint returned 404."
@ -94,11 +104,11 @@ async def verify_connection(conn: Connection) -> VerifyResult:
response.raise_for_status()
return VerifyResult("OK", True, "Connection verified.")
except httpx.ConnectError as exc:
return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc))
return VerifyResult("UNREACHABLE", False, _docker_hint(base_url, exc))
except httpx.TimeoutException as exc:
return VerifyResult("UNREACHABLE", False, f"Connection timed out: {exc}")
except httpx.HTTPError as exc:
return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc))
return VerifyResult("UNREACHABLE", False, _docker_hint(base_url, exc))
async def persist_verification(conn: Connection) -> VerifyResult:
@ -109,123 +119,193 @@ async def persist_verification(conn: Connection) -> VerifyResult:
return result
def _litellm_capabilities(model_string: str, model_id: str) -> dict[str, bool]:
capabilities = {
"chat": True,
"vision": False,
"tools": False,
"image_gen": False,
"embedding": False,
}
with contextlib.suppress(Exception):
capabilities["vision"] = bool(litellm.supports_vision(model=model_string))
with contextlib.suppress(Exception):
capabilities["tools"] = bool(litellm.supports_function_calling(model=model_string))
try:
info = litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {}
mode = str(info.get("mode") or "")
capabilities["embedding"] = mode == "embedding"
capabilities["image_gen"] = mode in {"image_generation", "image_generation_model"}
except Exception:
pass
return capabilities
def _allowlist(conn: Connection) -> set[str]:
"""Per-connection model-id allowlist stored in ``extra.model_ids``.
Empty/absent means "no restriction" (discover everything), mirroring
OpenWebUI's behaviour. A non-empty list restricts discovery to those ids —
essential for providers like OpenRouter that expose hundreds of models.
"""
raw = (conn.extra or {}).get("model_ids") or []
return {str(item).strip() for item in raw if str(item).strip()}
async def _discover_openai_shaped_models(conn: Connection, base_url: str | None) -> list[dict[str, Any]]:
if not base_url:
def _litellm_info(model_string: str, model_id: str) -> dict[str, Any]:
with contextlib.suppress(Exception):
info = litellm.get_model_info(model=model_string)
if isinstance(info, dict):
return info
return litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {}
def _classify_from_litellm(model_string: str, model_id: str) -> dict[str, Any]:
info = _litellm_info(model_string, model_id)
mode = info.get("mode")
supports_image_input = False
supports_tools = False
with contextlib.suppress(Exception):
supports_image_input = bool(litellm.supports_vision(model=model_string))
with contextlib.suppress(Exception):
supports_tools = bool(litellm.supports_function_calling(model=model_string))
return {
"supports_chat": mode in (None, "chat", "completion", "responses"),
"max_input_tokens": info.get("max_input_tokens") or info.get("max_tokens"),
"supports_image_input": supports_image_input,
"supports_tools": supports_tools,
"supports_image_generation": mode in {"image_generation", "image_generation_model"},
}
def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, Any]:
metadata = metadata or {}
spec = spec_for(conn.provider)
model_string, _ = to_litellm(conn, model_id)
facts = _classify_from_litellm(model_string, model_id)
if spec.transport == Transport.OLLAMA:
caps = set(metadata.get("capabilities") or [])
details = metadata.get("details") or {}
facts.update(
{
"supports_chat": "embedding" not in caps,
"supports_image_input": "vision" in caps or facts["supports_image_input"],
"supports_tools": "tools" in caps or facts["supports_tools"],
"supports_image_generation": False,
"max_input_tokens": metadata.get("context_length")
or metadata.get("num_ctx")
or details.get("context_length")
or facts["max_input_tokens"],
}
)
return facts
async def _discover_openai_shaped_models(
conn: Connection, base_url: str | None
) -> list[dict[str, Any]]:
resolved_base_url = base_url or _base_url_or_default(conn)
if not resolved_base_url:
return []
url = f"{ensure_v1(base_url)}/models"
url = f"{ensure_v1(resolved_base_url)}/models"
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
response = await client.get(url, headers=_auth_headers(conn))
response.raise_for_status()
return [
{
"model_id": item.get("id"),
"display_name": item.get("name") or item.get("id"),
"source": ModelSource.DISCOVERED,
"capabilities": derive_capabilities(conn, item.get("id"), item),
"metadata": item,
}
for item in response.json().get("data", [])
if item.get("id")
]
results: list[dict[str, Any]] = []
for item in response.json().get("data", []):
model_id = item.get("id")
if not model_id:
continue
results.append(
{
"model_id": model_id,
"display_name": item.get("name") or model_id,
"source": ModelSource.DISCOVERED,
**derive_capabilities(conn, model_id, item),
"metadata": item,
}
)
return results
async def _discover_anthropic_models(conn: Connection) -> list[dict[str, Any]]:
if not conn.base_url:
base_url = _base_url_or_default(conn)
if not base_url:
return []
url = f"{conn.base_url.rstrip('/')}/models"
url = f"{base_url.rstrip('/')}/models"
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
response = await client.get(url, headers=_anthropic_headers(conn))
response.raise_for_status()
models = response.json().get("data", [])
return [
{
"model_id": item.get("id"),
"display_name": item.get("display_name") or item.get("id"),
"source": ModelSource.DISCOVERED,
"capabilities": derive_capabilities(conn, item.get("id"), item),
"metadata": item,
}
for item in models
if item.get("id")
]
results: list[dict[str, Any]] = []
for item in response.json().get("data", []):
model_id = item.get("id")
if not model_id:
continue
results.append(
{
"model_id": model_id,
"display_name": item.get("display_name") or model_id,
"source": ModelSource.DISCOVERED,
**derive_capabilities(conn, model_id, item),
"metadata": item,
}
)
return results
def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, bool]:
metadata = metadata or {}
if conn.protocol == ConnectionProtocol.OLLAMA:
caps = metadata.get("capabilities") or []
capabilities = {
"chat": True,
"vision": "vision" in caps,
"tools": False,
"image_gen": False,
"embedding": "embedding" in caps,
}
return capabilities
async def _ollama_tags_then_show(conn: Connection) -> list[dict[str, Any]]:
if not conn.base_url:
return []
model_string, _ = to_litellm(conn, model_id)
return _litellm_capabilities(model_string, model_id)
base_url = conn.base_url.rstrip("/")
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
response = await client.get(f"{base_url}/api/tags", headers=_auth_headers(conn))
response.raise_for_status()
models = response.json().get("models", [])
results: list[dict[str, Any]] = []
for item in models:
model_id = item.get("model") or item.get("name")
if not model_id:
continue
metadata = dict(item)
with contextlib.suppress(Exception):
show_response = await client.post(
f"{base_url}/api/show",
json={"model": model_id},
headers=_auth_headers(conn),
)
show_response.raise_for_status()
metadata.update(show_response.json())
results.append(
{
"model_id": model_id,
"display_name": item.get("name") or model_id,
"source": ModelSource.DISCOVERED,
**derive_capabilities(conn, model_id, metadata),
"metadata": metadata,
}
)
return results
async def _openrouter_models(conn: Connection) -> list[dict[str, Any]]:
base_url = _base_url_or_default(conn) or "https://openrouter.ai/api/v1"
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
response = await client.get(f"{ensure_v1(base_url)}/models", headers=_auth_headers(conn))
response.raise_for_status()
return normalize_openrouter_models(response.json().get("data", []))
def _litellm_static_models(conn: Connection) -> list[dict[str, Any]]:
provider = conn.provider
prefix = spec_for(provider).litellm_prefix or provider
results: list[dict[str, Any]] = []
for model_string, metadata in litellm.model_cost.items():
if not isinstance(model_string, str) or not model_string.startswith(f"{prefix}/"):
continue
model_id = model_string.split("/", 1)[1]
results.append(
{
"model_id": model_id,
"display_name": metadata.get("display_name") or model_id,
"source": ModelSource.DISCOVERED,
**_classify_from_litellm(model_string, model_id),
"metadata": metadata,
}
)
return results
async def discover_models(conn: Connection) -> list[dict[str, Any]]:
allowlist = _allowlist(conn)
spec = spec_for(conn.provider)
if conn.protocol == ConnectionProtocol.OLLAMA:
url = f"{conn.base_url.rstrip('/')}/api/tags"
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
response = await client.get(url, headers=_auth_headers(conn))
response.raise_for_status()
models = response.json().get("models", [])
results = [
{
"model_id": item.get("model") or item.get("name"),
"display_name": item.get("name") or item.get("model"),
"source": ModelSource.DISCOVERED,
"capabilities": derive_capabilities(conn, item.get("model") or item.get("name"), item),
"metadata": item,
}
for item in models
if item.get("model") or item.get("name")
]
elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE:
results = await _discover_openai_shaped_models(conn, conn.base_url)
elif conn.protocol == ConnectionProtocol.ANTHROPIC:
if spec.discovery == "ollama":
results = await _ollama_tags_then_show(conn)
elif spec.discovery == "openrouter":
results = await _openrouter_models(conn)
elif spec.discovery == "anthropic_models":
results = await _discover_anthropic_models(conn)
elif spec.discovery == "openai_models":
results = await _discover_openai_shaped_models(conn, conn.base_url)
elif spec.discovery == "static":
results = _litellm_static_models(conn)
else:
results = []
@ -246,10 +326,7 @@ async def test_model(conn: Connection, model: Model) -> VerifyResult:
except Exception as exc:
return VerifyResult("UNREACHABLE", False, str(exc))
model.capabilities_verified = {
**(model.capabilities_verified or {}),
"chat": True,
}
model.supports_chat = True
return VerifyResult("OK", True, "Model test succeeded.")

View file

@ -12,6 +12,8 @@ from pathlib import Path
import httpx
from app.services.openrouter_model_normalizer import normalize_openrouter_models
logger = logging.getLogger(__name__)
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
@ -121,26 +123,13 @@ def _process_models(raw_models: list[dict]) -> list[dict]:
"""
processed: list[dict] = []
for model in raw_models:
model_id: str = model.get("id", "")
name: str = model.get("name", "")
context_length = model.get("context_length")
for normalized in normalize_openrouter_models(raw_models):
model_id: str = normalized["model_id"]
name: str = normalized.get("display_name") or model_id
context_length = normalized.get("max_input_tokens")
if "/" not in model_id:
continue
if not _is_text_output_model(model):
continue
if not _supports_tool_calling(model):
continue
if not _has_sufficient_context(model):
continue
if not _is_allowed_model(model):
continue
provider_slug, model_name = model_id.split("/", 1)
context_window = _format_context_length(context_length)

View file

@ -12,9 +12,7 @@ from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from app.db import Connection
PROTOCOL_OLLAMA = "OLLAMA"
PROTOCOL_OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE"
PROTOCOL_ANTHROPIC = "ANTHROPIC"
from app.services.provider_registry import Transport, spec_for
def ensure_v1(base_url: str | None) -> str | None:
@ -32,47 +30,25 @@ def _conn_value(conn: Connection | Mapping[str, Any], key: str) -> Any:
return getattr(conn, key)
def _protocol_value(protocol: Any) -> str:
return getattr(protocol, "value", str(protocol))
def default_litellm_provider(protocol: Any) -> str:
protocol_value = _protocol_value(protocol)
defaults = {
PROTOCOL_OLLAMA: "ollama_chat",
PROTOCOL_ANTHROPIC: "anthropic",
PROTOCOL_OPENAI_COMPATIBLE: "openai",
}
return defaults.get(protocol_value, "openai")
def _execution_api_base(protocol: str, base_url: str | None) -> str | None:
del protocol
if not base_url:
return None
return base_url.rstrip("/")
def to_litellm(
conn: Connection | Mapping[str, Any],
model_id: str,
) -> tuple[str, dict[str, Any]]:
"""Return ``(model_string, litellm_kwargs)`` for any model role."""
protocol = _protocol_value(_conn_value(conn, "protocol"))
provider = _conn_value(conn, "provider")
base_url = _conn_value(conn, "base_url")
api_key = _conn_value(conn, "api_key")
litellm_provider = (
_conn_value(conn, "litellm_provider") or default_litellm_provider(protocol)
)
extra = _conn_value(conn, "extra") or {}
spec = spec_for(provider)
kwargs: dict[str, Any] = {}
if api_key:
kwargs["api_key"] = api_key
model_string = f"{litellm_provider}/{model_id}" if litellm_provider else model_id
api_base = _execution_api_base(protocol, base_url)
if api_base:
prefix = spec.litellm_prefix or str(provider)
model_string = f"{prefix}/{model_id}" if prefix else model_id
if base_url:
api_base = ensure_v1(base_url) if spec.transport == Transport.OPENAI_COMPATIBLE else base_url.rstrip("/")
kwargs["api_base"] = api_base
if api_version := extra.get("api_version"):
@ -84,11 +60,11 @@ def to_litellm(
def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]:
"""Build an in-memory connection mapping from a global config."""
protocol = str(config.get("protocol") or PROTOCOL_OPENAI_COMPATIBLE)
litellm_provider = str(
config.get("litellm_provider")
provider = str(
config.get("provider")
or config.get("litellm_provider")
or config.get("custom_provider")
or default_litellm_provider(protocol)
or "openai"
)
extra: dict[str, Any] = {
"litellm_params": config.get("litellm_params") or {},
@ -96,8 +72,7 @@ def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]:
if config.get("api_version"):
extra["api_version"] = config.get("api_version")
return {
"protocol": protocol,
"litellm_provider": litellm_provider,
"provider": provider,
"base_url": config.get("api_base") or None,
"api_key": config.get("api_key") or None,
"extra": extra,
@ -105,7 +80,6 @@ def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]:
__all__ = [
"default_litellm_provider",
"ensure_v1",
"native_connection_from_config",
"to_litellm",

View file

@ -29,6 +29,13 @@ from app.services.quality_score import (
aggregate_health,
static_score_or,
)
from app.services.openrouter_model_normalizer import (
is_allowed_model as _shared_is_allowed_model,
is_compatible_provider as _shared_is_compatible_provider,
is_openrouter_image_model,
normalize_openrouter_models,
supports_image_input,
)
logger = logging.getLogger(__name__)
@ -292,24 +299,16 @@ def _generate_configs(
use_default: bool = settings.get("use_default_system_instructions", True)
citations_enabled: bool = settings.get("citations_enabled", True)
text_models = [
m
for m in raw_models
if _is_text_output_model(m)
and _supports_tool_calling(m)
and _has_sufficient_context(m)
and _is_compatible_provider(m)
and _is_allowed_model(m)
and "/" in m.get("id", "")
]
text_models = normalize_openrouter_models(raw_models)
configs: list[dict] = []
taken: set[int] = set()
now_ts = int(time.time())
for model in text_models:
model_id: str = model["id"]
name: str = model.get("name", model_id)
for normalized in text_models:
model = normalized.get("metadata") or {}
model_id: str = normalized["model_id"]
name: str = normalized.get("display_name") or model_id
tier = _openrouter_tier(model)
static_q = static_score_or(model, now_ts=now_ts)
@ -323,7 +322,7 @@ def _generate_configs(
"seo_enabled": seo_enabled,
"seo_slug": None,
"quota_reserve_tokens": quota_reserve_tokens,
"litellm_provider": "openrouter",
"provider": "openrouter",
"model_name": model_id,
"api_key": api_key,
"api_base": "https://openrouter.ai/api/v1",
@ -345,7 +344,7 @@ def _generate_configs(
# ``stream_new_chat`` as a fail-fast safety net before the
# OpenRouter request would otherwise 404 with
# ``"No endpoints found that support image input"``.
"supports_image_input": _supports_image_input(model),
"supports_image_input": bool(normalized.get("supports_image_input")),
_OPENROUTER_DYNAMIC_MARKER: True,
# Auto (Fastest) ranking metadata. ``quality_score`` is initialised
# to the static score and gets re-blended with health on the next
@ -403,10 +402,7 @@ def _generate_image_gen_configs(
image_models = [
m
for m in raw_models
if _is_image_output_model(m)
and _is_compatible_provider(m)
and _is_allowed_model(m)
and "/" in m.get("id", "")
if is_openrouter_image_model(m)
]
configs: list[dict] = []
@ -420,7 +416,7 @@ def _generate_image_gen_configs(
"id": _stable_config_id(model_id, id_offset, taken),
"name": name,
"description": f"{name} via OpenRouter (image generation)",
"litellm_provider": "openrouter",
"provider": "openrouter",
"model_name": model_id,
"api_key": api_key,
"api_base": "https://openrouter.ai/api/v1",
@ -468,9 +464,9 @@ def _generate_vision_llm_configs(
vision_models = [
m
for m in raw_models
if _is_vision_input_model(m)
and _is_compatible_provider(m)
and _is_allowed_model(m)
if supports_image_input(m)
and _shared_is_compatible_provider(m)
and _shared_is_allowed_model(m)
and "/" in m.get("id", "")
]
@ -499,7 +495,7 @@ def _generate_vision_llm_configs(
"id": _stable_config_id(model_id, id_offset, taken),
"name": name,
"description": f"{name} via OpenRouter (vision)",
"litellm_provider": "openrouter",
"provider": "openrouter",
"model_name": model_id,
"api_key": api_key,
"api_base": "https://openrouter.ai/api/v1",
@ -544,11 +540,9 @@ class OpenRouterIntegrationService:
# Cached raw catalogue from the most recent fetch. Image / vision
# emitters reuse this to avoid a second network call per surface.
self._raw_models: list[dict] = []
# Image / vision config caches (only populated when the matching
# opt-in flag is true on initialize). Refreshed in lockstep with
# the chat catalogue.
# Image config cache (only populated when the matching opt-in flag is
# true on initialize). Refreshed in lockstep with the chat catalogue.
self._image_configs: list[dict] = []
self._vision_configs: list[dict] = []
@classmethod
def get_instance(cls) -> "OpenRouterIntegrationService":
@ -583,7 +577,7 @@ class OpenRouterIntegrationService:
self._configs_by_id = {c["id"]: c for c in self._configs}
self._raw_pricing = _extract_raw_pricing(raw_models)
# Populate image / vision caches when their opt-in flag is set.
# Populate image cache when its opt-in flag is set.
# Empty otherwise so the accessors return [] without re-running
# filters every refresh.
if settings.get("image_generation_enabled"):
@ -595,15 +589,6 @@ class OpenRouterIntegrationService:
else:
self._image_configs = []
if settings.get("vision_enabled"):
self._vision_configs = _generate_vision_llm_configs(raw_models, settings)
logger.info(
"OpenRouter integration: vision LLM emission ON (%d models)",
len(self._vision_configs),
)
else:
self._vision_configs = []
self._initialized = True
tier_counts = self._tier_counts(self._configs)
@ -657,9 +642,9 @@ class OpenRouterIntegrationService:
self._configs = new_configs
self._configs_by_id = new_by_id
# Image / vision lists are atomic-swapped the same way: filter out
# Image list is atomic-swapped the same way: filter out
# the previous dynamic entries from the live config list and append
# the freshly generated ones. No-ops when the opt-in flag is off.
# the freshly generated ones. No-op when the opt-in flag is off.
if self._settings.get("image_generation_enabled"):
new_image = _generate_image_gen_configs(raw_models, self._settings)
static_image = [
@ -670,16 +655,6 @@ class OpenRouterIntegrationService:
app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image
self._image_configs = new_image
if self._settings.get("vision_enabled"):
new_vision = _generate_vision_llm_configs(raw_models, self._settings)
static_vision = [
c
for c in app_config.GLOBAL_VISION_LLM_CONFIGS
if not c.get(_OPENROUTER_DYNAMIC_MARKER)
]
app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision
self._vision_configs = new_vision
# Catalogue churn invalidates per-config "recently healthy" credit
# earned by the previous turn's preflight. Drop the whole table so
# the next turn re-probes against the freshly loaded configs.
@ -701,7 +676,7 @@ class OpenRouterIntegrationService:
)
# Re-blend health scores against the freshly fetched catalogue. Also
# re-stamps health for any YAML-curated cfg with litellm_provider=openrouter
# re-stamps health for any YAML-curated cfg with provider=openrouter
# so a hand-picked dead OR model is gated like a dynamic one.
await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
@ -778,7 +753,7 @@ class OpenRouterIntegrationService:
the entire previous cycle's cache for this run.
"""
or_cfgs = [
c for c in configs if str(c.get("litellm_provider", "")).lower() == "openrouter"
c for c in configs if str(c.get("provider", "")).lower() == "openrouter"
]
if not or_cfgs:
return
@ -959,17 +934,6 @@ class OpenRouterIntegrationService:
"""
return list(self._image_configs)
def get_vision_llm_configs(self) -> list[dict]:
"""Return the dynamic OpenRouter vision-LLM configs (empty list
when the ``vision_enabled`` flag is off).
Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token``
so ``pricing_registration`` can teach LiteLLM the cost of these
models the same way it does for chat which keeps the billable
wrapper able to debit accurate micro-USD on a vision call.
"""
return list(self._vision_configs)
def get_raw_pricing(self) -> dict[str, dict[str, str]]:
"""Return the cached raw OpenRouter pricing map.

View file

@ -0,0 +1,121 @@
"""Shared OpenRouter model normalization.
OpenRouter metadata is richer than generic OpenAI-compatible ``/models``
responses. Keep all OpenRouter filtering and capability extraction here so
GLOBAL catalogue generation and BYOK discovery agree.
"""
from __future__ import annotations
from typing import Any
from app.db import ModelSource
MIN_CONTEXT_LENGTH = 100_000
EXCLUDED_PROVIDER_SLUGS = {"amazon"}
EXCLUDED_MODEL_IDS: set[str] = {
"openai/gpt-4-1106-preview",
"openai/gpt-4-turbo-preview",
"openai/gpt-4o:extended",
"arcee-ai/virtuoso-large",
"openai/o3-deep-research",
"openai/o4-mini-deep-research",
"openrouter/free",
}
EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",)
def is_text_output_model(model: dict[str, Any]) -> bool:
output_mods = model.get("architecture", {}).get("output_modalities", [])
return output_mods == ["text"]
def is_image_output_model(model: dict[str, Any]) -> bool:
output_mods = model.get("architecture", {}).get("output_modalities", []) or []
return "image" in output_mods
def supports_image_input(model: dict[str, Any]) -> bool:
input_mods = model.get("architecture", {}).get("input_modalities", []) or []
return "image" in input_mods
def supports_tool_calling(model: dict[str, Any]) -> bool:
supported = model.get("supported_parameters") or []
return "tools" in supported
def has_sufficient_context(model: dict[str, Any]) -> bool:
return int(model.get("context_length") or 0) >= MIN_CONTEXT_LENGTH
def is_compatible_provider(model: dict[str, Any]) -> bool:
model_id = str(model.get("id") or "")
slug = model_id.split("/", 1)[0] if "/" in model_id else ""
return slug not in EXCLUDED_PROVIDER_SLUGS
def is_allowed_model(model: dict[str, Any]) -> bool:
model_id = str(model.get("id") or "")
if model_id in EXCLUDED_MODEL_IDS:
return False
base_id = model_id.split(":")[0]
return not base_id.endswith(EXCLUDED_MODEL_SUFFIXES)
def is_openrouter_chat_model(model: dict[str, Any]) -> bool:
return (
"/" in str(model.get("id") or "")
and is_text_output_model(model)
and supports_tool_calling(model)
and has_sufficient_context(model)
and is_compatible_provider(model)
and is_allowed_model(model)
)
def is_openrouter_image_model(model: dict[str, Any]) -> bool:
return (
"/" in str(model.get("id") or "")
and is_image_output_model(model)
and is_compatible_provider(model)
and is_allowed_model(model)
)
def normalize_openrouter_models(raw_models: list[dict[str, Any]]) -> list[dict[str, Any]]:
normalized: list[dict[str, Any]] = []
for model in raw_models:
if not is_openrouter_chat_model(model):
continue
model_id = str(model.get("id") or "")
normalized.append(
{
"model_id": model_id,
"display_name": model.get("name") or model_id,
"source": ModelSource.DISCOVERED,
"supports_chat": True,
"max_input_tokens": model.get("context_length"),
"supports_image_input": supports_image_input(model),
"supports_tools": supports_tool_calling(model),
"supports_image_generation": False,
"metadata": model,
}
)
return normalized
__all__ = [
"MIN_CONTEXT_LENGTH",
"has_sufficient_context",
"is_allowed_model",
"is_compatible_provider",
"is_image_output_model",
"is_openrouter_chat_model",
"is_openrouter_image_model",
"is_text_output_model",
"normalize_openrouter_models",
"supports_image_input",
"supports_tool_calling",
]

View file

@ -143,7 +143,7 @@ def _register_chat_shape_configs(
sample_keys: list[str] = []
for cfg in configs:
provider = str(cfg.get("litellm_provider") or "").lower()
provider = str(cfg.get("provider") or cfg.get("litellm_provider") or "").lower()
model_name = str(cfg.get("model_name") or "").strip()
litellm_params = cfg.get("litellm_params") or {}
base_model = str(litellm_params.get("base_model") or model_name).strip()
@ -216,9 +216,8 @@ def _register_chat_shape_configs(
def register_pricing_from_global_configs() -> None:
"""Register pricing for every known LLM deployment with LiteLLM.
Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS``
so vision calls (during indexing) can resolve cost the same way chat
calls do namely:
Walks ``config.GLOBAL_LLM_CONFIGS`` so chat and vision calls can resolve
cost from the same chat-shaped deployment configs:
1. ``OPENROUTER``: pulls the cached raw pricing from
``OpenRouterIntegrationService`` (populated during its own
@ -245,10 +244,7 @@ def register_pricing_from_global_configs() -> None:
from app.config import config as app_config
chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or [])
vision_configs: list[dict] = list(
getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or []
)
if not chat_configs and not vision_configs:
if not chat_configs:
logger.info("[PricingRegistration] no global configs to register")
return
@ -267,7 +263,3 @@ def register_pricing_from_global_configs() -> None:
if chat_configs:
_register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat")
if vision_configs:
_register_chat_shape_configs(
vision_configs, or_pricing=or_pricing, label="vision"
)

View file

@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
def _candidate_model_strings(
*,
litellm_provider: str | None,
provider: str | None,
model_name: str | None,
base_model: str | None,
custom_provider: str | None,
@ -78,7 +78,7 @@ def _candidate_model_strings(
seen.add(key)
candidates.append(key)
provider_prefix = custom_provider or litellm_provider
provider_prefix = custom_provider or provider
primary_model = base_model or model_name
bare_model = model_name
@ -113,7 +113,7 @@ def _candidate_model_strings(
def derive_supports_image_input(
*,
litellm_provider: str | None = None,
provider: str | None = None,
model_name: str | None = None,
base_model: str | None = None,
custom_provider: str | None = None,
@ -147,7 +147,7 @@ def derive_supports_image_input(
return False
for model_string, custom_llm_provider in _candidate_model_strings(
litellm_provider=litellm_provider,
provider=provider,
model_name=model_name,
base_model=base_model,
custom_provider=custom_provider,
@ -172,7 +172,7 @@ def derive_supports_image_input(
def is_known_text_only_chat_model(
*,
litellm_provider: str | None = None,
provider: str | None = None,
model_name: str | None = None,
base_model: str | None = None,
custom_provider: str | None = None,
@ -193,7 +193,7 @@ def is_known_text_only_chat_model(
leads to the regression we're fixing here.
"""
for model_string, custom_llm_provider in _candidate_model_strings(
litellm_provider=litellm_provider,
provider=provider,
model_name=model_name,
base_model=base_model,
custom_provider=custom_provider,

View file

@ -0,0 +1,98 @@
"""Provider registry for model connections.
The provider string is the single public identity axis. This registry only
describes providers whose behavior differs from LiteLLM's native default.
"""
from __future__ import annotations
from dataclasses import dataclass
from enum import StrEnum
from typing import Literal
class Transport(StrEnum):
NATIVE = "NATIVE"
OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE"
OLLAMA = "OLLAMA"
DiscoveryKind = Literal[
"ollama",
"openai_models",
"anthropic_models",
"openrouter",
"static",
"none",
]
AuthStyle = Literal["bearer", "x-api-key", "none", "native"]
@dataclass(frozen=True)
class ProviderSpec:
transport: Transport
litellm_prefix: str | None
discovery: DiscoveryKind
default_base_url: str | None
base_url_required: bool
auth_style: AuthStyle
REGISTRY: dict[str, ProviderSpec] = {
"openai": ProviderSpec(
Transport.NATIVE, "openai", "openai_models", None, False, "bearer"
),
"anthropic": ProviderSpec(
Transport.NATIVE, "anthropic", "anthropic_models", None, False, "x-api-key"
),
"azure": ProviderSpec(Transport.NATIVE, "azure", "static", None, True, "native"),
"vertex_ai": ProviderSpec(
Transport.NATIVE, "vertex_ai", "static", None, False, "native"
),
"bedrock": ProviderSpec(
Transport.NATIVE, "bedrock", "static", None, False, "native"
),
"openrouter": ProviderSpec(
Transport.OPENAI_COMPATIBLE,
"openrouter",
"openrouter",
"https://openrouter.ai/api/v1",
False,
"bearer",
),
"openai_compatible": ProviderSpec(
Transport.OPENAI_COMPATIBLE,
"openai",
"openai_models",
None,
True,
"bearer",
),
"lm_studio": ProviderSpec(
Transport.OPENAI_COMPATIBLE,
"openai",
"openai_models",
"http://localhost:1234/v1",
True,
"bearer",
),
"ollama_chat": ProviderSpec(
Transport.OLLAMA,
"ollama_chat",
"ollama",
"http://localhost:11434",
True,
"none",
),
}
def spec_for(provider: str | None) -> ProviderSpec:
provider_key = (provider or "").strip()
return REGISTRY.get(provider_key) or ProviderSpec(
Transport.NATIVE, provider_key or "openai", "static", None, False, "native"
)
__all__ = ["REGISTRY", "ProviderSpec", "Transport", "spec_for"]

View file

@ -273,7 +273,7 @@ def static_score_yaml(cfg: dict) -> int:
listed this model. Pricing / context fall through to lazy ``litellm``
lookups; failures are silent (we just lose those sub-points).
"""
provider = str(cfg.get("litellm_provider", "")).lower()
provider = str(cfg.get("provider") or cfg.get("litellm_provider") or "").lower()
base = PROVIDER_PRESTIGE_YAML.get(provider, 15)
model_name = cfg.get("model_name") or ""

View file

@ -40,7 +40,7 @@ def check_image_input_capability(
else None
)
if not is_known_text_only_chat_model(
litellm_provider=agent_config.provider,
provider=agent_config.provider,
model_name=agent_config.model_name,
base_model=agent_base_model,
custom_provider=agent_config.custom_provider,

View file

@ -22,6 +22,7 @@ from app.agents.chat.runtime.llm_config import (
)
from app.config import config
from app.db import Model, SearchSpace
from app.services.model_capabilities import has_capability
from app.services.model_resolver import to_litellm
@ -96,7 +97,7 @@ async def load_llm_bundle(
model_id=config_id,
search_space=search_space,
)
if not model or not (model.capabilities or {}).get("chat"):
if not model or not has_capability(model, "chat"):
return (
None,
None,
@ -106,12 +107,12 @@ async def load_llm_bundle(
agent_config = _agent_config_from_resolved(
config_id=config_id,
config_name=model.display_name or model.model_id,
provider=model.connection.litellm_provider or "",
provider=model.connection.provider or "",
model_name=model.model_id,
api_key=model.connection.api_key,
api_base=model.connection.base_url,
litellm_params=(model.connection.extra or {}).get("litellm_params"),
supports_image_input=bool((model.capabilities or {}).get("vision")),
supports_image_input=has_capability(model, "vision"),
billing_tier="free",
)
return (
@ -121,7 +122,7 @@ async def load_llm_bundle(
)
global_model = next((m for m in config.GLOBAL_MODELS if m.get("id") == config_id), None)
if not global_model or not (global_model.get("capabilities") or {}).get("chat"):
if not global_model or not has_capability(global_model, "chat"):
return None, None, f"Failed to load global chat model with id {config_id}"
global_connection = next(
(
@ -137,12 +138,12 @@ async def load_llm_bundle(
agent_config = _agent_config_from_resolved(
config_id=config_id,
config_name=global_model.get("display_name") or global_model.get("model_id"),
provider=global_connection.get("litellm_provider") or "",
provider=global_connection.get("provider") or "",
model_name=global_model["model_id"],
api_key=global_connection.get("api_key"),
api_base=global_connection.get("base_url"),
litellm_params=(global_connection.get("extra") or {}).get("litellm_params"),
supports_image_input=bool((global_model.get("capabilities") or {}).get("vision")),
supports_image_input=has_capability(global_model, "vision"),
billing_tier=str(global_model.get("billing_tier", "free")).lower(),
)
return (