refactor(model-connections): streamline global model config persistence

This commit is contained in:
Anish Sarkar 2026-06-11 18:20:53 +05:30
parent 3f01642199
commit c6a25cc1fe
13 changed files with 277 additions and 224 deletions

View file

@ -20,7 +20,7 @@ depends_on: str | Sequence[str] | None = None
connection_protocol = postgresql.ENUM(
"OLLAMA",
"OPENAI_COMPATIBLE",
"NATIVE",
"ANTHROPIC",
name="connectionprotocol",
create_type=False,
)
@ -39,122 +39,172 @@ model_source = postgresql.ENUM(
)
def _table_exists(table_name: str) -> bool:
    """Return True when ``table_name`` is listed by the inspector of the current bind.

    The inspector is created fresh from ``op.get_bind()`` on every call and
    queried through ``get_table_names()`` with its default arguments.
    """
    inspector = sa.inspect(op.get_bind())
    known_tables = inspector.get_table_names()
    return table_name in known_tables
def _column_exists(table_name: str, column_name: str) -> bool:
    """Return True when ``table_name`` exists and has a column named ``column_name``.

    A missing table yields False without the columns being inspected.
    """
    if _table_exists(table_name):
        inspector = sa.inspect(op.get_bind())
        column_names = [entry["name"] for entry in inspector.get_columns(table_name)]
        return column_name in column_names
    return False
def _index_exists(table_name: str, index_name: str) -> bool:
    """Return True when ``table_name`` exists and carries an index named ``index_name``.

    A missing table yields False without the indexes being inspected.
    """
    if _table_exists(table_name):
        inspector = sa.inspect(op.get_bind())
        index_names = [entry["name"] for entry in inspector.get_indexes(table_name)]
        return index_name in index_names
    return False
def _create_index_if_missing(
    index_name: str,
    table_name: str,
    columns: list[str],
) -> None:
    """Create a non-unique index on ``table_name`` unless ``index_name`` is already there.

    Args:
        index_name: Name of the index to look for and, if absent, create.
        table_name: Table the index belongs to.
        columns: Column names the index covers.
    """
    if _index_exists(table_name, index_name):
        # Already present: nothing to do, which keeps re-runs harmless.
        return
    op.create_index(index_name, table_name, columns, unique=False)
def _add_searchspace_column_if_missing(column_name: str) -> None:
    """Add a nullable integer column to ``searchspaces`` unless it already exists.

    Args:
        column_name: Name of the column to add.
    """
    if _column_exists("searchspaces", column_name):
        return
    nullable_int_column = sa.Column(column_name, sa.Integer(), nullable=True)
    op.add_column("searchspaces", nullable_int_column)
def upgrade() -> None:
bind = op.get_bind()
connection_protocol.create(bind, checkfirst=True)
op.execute("ALTER TYPE connectionprotocol ADD VALUE IF NOT EXISTS 'ANTHROPIC'")
connection_scope.create(bind, checkfirst=True)
model_source.create(bind, checkfirst=True)
op.create_table(
if _table_exists("connections"):
if _column_exists("connections", "native_provider") and not _column_exists(
"connections", "litellm_provider"
):
op.alter_column(
"connections",
"native_provider",
new_column_name="litellm_provider",
existing_type=sa.String(length=100),
existing_nullable=True,
)
elif not _column_exists("connections", "litellm_provider"):
op.add_column(
"connections",
sa.Column("litellm_provider", sa.String(length=100), nullable=True),
)
else:
op.create_table(
"connections",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("protocol", connection_protocol, nullable=False),
sa.Column("litellm_provider", sa.String(length=100), nullable=True),
sa.Column("base_url", sa.String(length=500), nullable=True),
sa.Column("api_key", sa.String(), nullable=True),
sa.Column(
"extra",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column("scope", connection_scope, nullable=False),
sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False),
sa.Column("search_space_id", sa.Integer(), nullable=True),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("last_verified_at", sa.TIMESTAMP(timezone=True), nullable=True),
sa.Column("last_status", sa.String(length=50), nullable=True),
sa.Column("last_error", sa.Text(), nullable=True),
sa.CheckConstraint(
"(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR "
"(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR "
"(scope = 'USER' AND user_id IS NOT NULL)",
name="ck_connections_scope_owner",
),
sa.ForeignKeyConstraint(
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
if _index_exists("connections", "ix_connections_native_provider") and not _index_exists(
"connections", "ix_connections_litellm_provider"
):
op.execute(
"ALTER INDEX ix_connections_native_provider "
"RENAME TO ix_connections_litellm_provider"
)
_create_index_if_missing("ix_connections_protocol", "connections", ["protocol"])
_create_index_if_missing(
"ix_connections_litellm_provider",
"connections",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("protocol", connection_protocol, nullable=False),
sa.Column("native_provider", sa.String(length=100), nullable=True),
sa.Column("base_url", sa.String(length=500), nullable=True),
sa.Column("api_key", sa.String(), nullable=True),
sa.Column(
"extra",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column("scope", connection_scope, nullable=False),
sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False),
sa.Column("search_space_id", sa.Integer(), nullable=True),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("last_verified_at", sa.TIMESTAMP(timezone=True), nullable=True),
sa.Column("last_status", sa.String(length=50), nullable=True),
sa.Column("last_error", sa.Text(), nullable=True),
sa.CheckConstraint(
"(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR "
"(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR "
"(scope = 'USER' AND user_id IS NOT NULL)",
name="ck_connections_scope_owner",
),
sa.ForeignKeyConstraint(
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
["litellm_provider"],
)
op.create_index(op.f("ix_connections_protocol"), "connections", ["protocol"], unique=False)
op.create_index(
op.f("ix_connections_native_provider"),
"connections",
["native_provider"],
unique=False,
)
op.create_index(op.f("ix_connections_scope"), "connections", ["scope"], unique=False)
_create_index_if_missing("ix_connections_scope", "connections", ["scope"])
op.create_table(
"models",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("connection_id", sa.Integer(), nullable=False),
sa.Column("model_id", sa.String(length=255), nullable=False),
sa.Column("display_name", sa.String(length=255), nullable=True),
sa.Column(
"source",
model_source,
server_default="DISCOVERED",
nullable=False,
),
sa.Column(
"capabilities",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column(
"capabilities_declared",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column(
"capabilities_verified",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column(
"capabilities_override",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column("embedding_dimension", sa.Integer(), nullable=True),
sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False),
sa.Column("billing_tier", sa.String(length=50), nullable=True),
sa.Column(
"catalog",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.ForeignKeyConstraint(["connection_id"], ["connections.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"connection_id", "model_id", name="uq_models_connection_model_id"
),
)
op.create_index(op.f("ix_models_connection_id"), "models", ["connection_id"], unique=False)
op.create_index("ix_models_model_id", "models", ["model_id"], unique=False)
op.create_index(op.f("ix_models_billing_tier"), "models", ["billing_tier"], unique=False)
if not _table_exists("models"):
op.create_table(
"models",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("connection_id", sa.Integer(), nullable=False),
sa.Column("model_id", sa.String(length=255), nullable=False),
sa.Column("display_name", sa.String(length=255), nullable=True),
sa.Column(
"source",
model_source,
server_default="DISCOVERED",
nullable=False,
),
sa.Column(
"capabilities",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column(
"capabilities_declared",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column(
"capabilities_verified",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column(
"capabilities_override",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.Column("embedding_dimension", sa.Integer(), nullable=True),
sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False),
sa.Column("billing_tier", sa.String(length=50), nullable=True),
sa.Column(
"catalog",
postgresql.JSONB(astext_type=sa.Text()),
server_default=sa.text("'{}'::jsonb"),
nullable=False,
),
sa.ForeignKeyConstraint(["connection_id"], ["connections.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"connection_id", "model_id", name="uq_models_connection_model_id"
),
)
_create_index_if_missing("ix_models_connection_id", "models", ["connection_id"])
_create_index_if_missing("ix_models_model_id", "models", ["model_id"])
_create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"])
op.add_column(
"searchspaces",
sa.Column("chat_model_id", sa.Integer(), nullable=True),
)
op.add_column(
"searchspaces",
sa.Column("image_gen_model_id", sa.Integer(), nullable=True),
)
op.add_column(
"searchspaces",
sa.Column("vision_model_id", sa.Integer(), nullable=True),
)
_add_searchspace_column_if_missing("chat_model_id")
_add_searchspace_column_if_missing("image_gen_model_id")
_add_searchspace_column_if_missing("vision_model_id")
def downgrade() -> None:
@ -168,7 +218,7 @@ def downgrade() -> None:
op.drop_table("models")
op.drop_index(op.f("ix_connections_scope"), table_name="connections")
op.drop_index(op.f("ix_connections_native_provider"), table_name="connections")
op.drop_index(op.f("ix_connections_litellm_provider"), table_name="connections")
op.drop_index(op.f("ix_connections_protocol"), table_name="connections")
op.drop_table("connections")

View file

@ -78,8 +78,7 @@ def load_global_llm_configs():
# stamps) never leak into the cached YAML structure.
configs = copy.deepcopy(data.get("global_llm_configs", []))
# Lazy import keeps the `app.config` -> `app.services` edge one-way
# and matches the `provider_api_base` pattern used elsewhere.
# Lazy import keeps the `app.config` -> `app.services` edge one-way.
from app.services.provider_capabilities import derive_supports_image_input
seen_slugs: dict[str, int] = {}
@ -104,7 +103,7 @@ def load_global_llm_configs():
else None
)
cfg["supports_image_input"] = derive_supports_image_input(
provider=cfg.get("provider"),
litellm_provider=cfg.get("litellm_provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
@ -123,7 +122,7 @@ def load_global_llm_configs():
# Stamp Auto (Fastest) ranking metadata. YAML configs are always
# Tier A — operator-curated, locked first when premium-eligible.
# The OpenRouter refresh tick later re-stamps health for any cfg
# whose provider == "OPENROUTER" via _enrich_health.
# whose litellm_provider == "openrouter" via _enrich_health.
try:
from app.services.quality_score import static_score_yaml
@ -133,7 +132,7 @@ def load_global_llm_configs():
cfg["quality_score_static"] = static_q
cfg["quality_score"] = static_q
cfg["quality_score_health"] = None
# YAML cfgs whose provider is OPENROUTER are also subject
# YAML cfgs whose litellm_provider is openrouter are also subject
# to health gating against their own /endpoints data — a
# hand-picked dead OR model is still dead. _enrich_health
# re-stamps health_gated for them on the next refresh tick.

View file

@ -18,7 +18,7 @@
# - Configure router_settings below to customize the load balancing behavior
#
# Static config shape:
# - Connection fields: provider, api_key, api_base, api_version
# - Connection fields: litellm_provider, api_key, api_base, api_version
# - Model fields: model_name, billing_tier, rpm/tpm, litellm_params
# - Prompt defaults: system_instructions, citations_enabled
# IDs share one GLOBAL model namespace across chat, vision, and image generation.
@ -75,10 +75,10 @@ global_llm_configs:
seo_enabled: true
seo_slug: "gpt-4-turbo"
quota_reserve_tokens: 4000
provider: "OPENAI"
litellm_provider: "openai"
model_name: "gpt-4-turbo-preview"
api_key: "sk-your-openai-api-key-here"
api_base: ""
api_base: "https://api.openai.com/v1"
# Rate limits for load balancing (requests/tokens per minute)
rpm: 500 # Requests per minute
tpm: 100000 # Tokens per minute
@ -99,10 +99,10 @@ global_llm_configs:
seo_enabled: true
seo_slug: "claude-3-opus"
quota_reserve_tokens: 4000
provider: "ANTHROPIC"
litellm_provider: "anthropic"
model_name: "claude-3-opus-20240229"
api_key: "sk-ant-your-anthropic-api-key-here"
api_base: ""
api_base: "https://api.anthropic.com/v1"
rpm: 1000
tpm: 100000
litellm_params:
@ -121,10 +121,10 @@ global_llm_configs:
seo_enabled: true
seo_slug: "gpt-3.5-turbo-fast"
quota_reserve_tokens: 2000
provider: "OPENAI"
litellm_provider: "openai"
model_name: "gpt-3.5-turbo"
api_key: "sk-your-openai-api-key-here"
api_base: ""
api_base: "https://api.openai.com/v1"
rpm: 3500 # GPT-3.5 has higher rate limits
tpm: 200000
litellm_params:
@ -143,7 +143,7 @@ global_llm_configs:
seo_enabled: true
seo_slug: "deepseek-chat-chinese"
quota_reserve_tokens: 4000
provider: "DEEPSEEK"
litellm_provider: "openai"
model_name: "deepseek-chat"
api_key: "your-deepseek-api-key-here"
api_base: "https://api.deepseek.com/v1"
@ -175,7 +175,7 @@ global_llm_configs:
seo_enabled: true
seo_slug: "azure-gpt-4o"
quota_reserve_tokens: 4000
provider: "AZURE"
litellm_provider: "azure"
# model_name format for Azure: azure/<your-deployment-name>
model_name: "azure/gpt-4o-deployment"
api_key: "your-azure-api-key-here"
@ -203,7 +203,7 @@ global_llm_configs:
seo_enabled: true
seo_slug: "azure-gpt-4-turbo"
quota_reserve_tokens: 4000
provider: "AZURE"
litellm_provider: "azure"
model_name: "azure/gpt-4-turbo-deployment"
api_key: "your-azure-api-key-here"
api_base: "https://your-resource.openai.azure.com"
@ -227,10 +227,10 @@ global_llm_configs:
seo_enabled: true
seo_slug: "groq-llama-3"
quota_reserve_tokens: 8000
provider: "GROQ"
litellm_provider: "groq"
model_name: "llama3-70b-8192"
api_key: "your-groq-api-key-here"
api_base: ""
api_base: "https://api.groq.com/openai/v1"
rpm: 30 # Groq has lower rate limits on free tier
tpm: 14400
litellm_params:
@ -249,7 +249,7 @@ global_llm_configs:
seo_enabled: true
seo_slug: "minimax-m3"
quota_reserve_tokens: 4000
provider: "MINIMAX"
litellm_provider: "openai"
model_name: "MiniMax-M3"
api_key: "your-minimax-api-key-here"
api_base: "https://api.minimax.io/v1"
@ -288,10 +288,10 @@ global_llm_configs:
anonymous_enabled: false
seo_enabled: false
quota_reserve_tokens: 1000
provider: "OPENAI"
litellm_provider: "openai"
model_name: "gpt-4o-mini"
api_key: "sk-your-openai-api-key-here"
api_base: ""
api_base: "https://api.openai.com/v1"
rpm: 3500
tpm: 200000
litellm_params:
@ -391,10 +391,10 @@ global_image_generation_configs:
- id: -2001
name: "Global DALL-E 3"
description: "OpenAI's DALL-E 3 for high-quality image generation"
provider: "OPENAI"
litellm_provider: "openai"
model_name: "dall-e-3"
api_key: "sk-your-openai-api-key-here"
api_base: ""
api_base: "https://api.openai.com/v1"
rpm: 50 # Requests per minute (image gen is rate-limited by RPM, not tokens)
litellm_params: {}
@ -402,10 +402,10 @@ global_image_generation_configs:
- id: -2002
name: "Global GPT Image 1"
description: "OpenAI's GPT Image 1 model"
provider: "OPENAI"
litellm_provider: "openai"
model_name: "gpt-image-1"
api_key: "sk-your-openai-api-key-here"
api_base: ""
api_base: "https://api.openai.com/v1"
rpm: 50
litellm_params: {}
@ -413,7 +413,7 @@ global_image_generation_configs:
- id: -2003
name: "Global Azure DALL-E 3"
description: "Azure-hosted DALL-E 3 deployment"
provider: "AZURE_OPENAI"
litellm_provider: "azure"
model_name: "azure/dall-e-3-deployment"
api_key: "your-azure-api-key-here"
api_base: "https://your-resource.openai.azure.com"
@ -426,10 +426,10 @@ global_image_generation_configs:
# - id: -2004
# name: "Global Gemini Image Gen"
# description: "Google Gemini image generation via OpenRouter"
# provider: "OPENROUTER"
# litellm_provider: "openrouter"
# model_name: "google/gemini-2.5-flash-image"
# api_key: "your-openrouter-api-key-here"
# api_base: ""
# api_base: "https://openrouter.ai/api/v1"
# rpm: 30
# litellm_params: {}
@ -455,10 +455,10 @@ global_vision_llm_configs:
- id: -1001
name: "Global GPT-4o Vision"
description: "OpenAI's GPT-4o with strong vision capabilities"
provider: "OPENAI"
litellm_provider: "openai"
model_name: "gpt-4o"
api_key: "sk-your-openai-api-key-here"
api_base: ""
api_base: "https://api.openai.com/v1"
rpm: 500
tpm: 100000
litellm_params:
@ -469,10 +469,10 @@ global_vision_llm_configs:
- id: -1002
name: "Global Gemini 2.0 Flash"
description: "Google's fast vision model with large context"
provider: "GOOGLE"
litellm_provider: "gemini"
model_name: "gemini-2.0-flash"
api_key: "your-google-ai-api-key-here"
api_base: ""
api_base: "https://generativelanguage.googleapis.com/v1beta"
rpm: 1000
tpm: 200000
litellm_params:
@ -483,10 +483,10 @@ global_vision_llm_configs:
- id: -1003
name: "Global Claude 3.5 Sonnet Vision"
description: "Anthropic's Claude 3.5 Sonnet with vision support"
provider: "ANTHROPIC"
litellm_provider: "anthropic"
model_name: "claude-3-5-sonnet-20241022"
api_key: "sk-ant-your-anthropic-api-key-here"
api_base: ""
api_base: "https://api.anthropic.com/v1"
rpm: 1000
tpm: 100000
litellm_params:
@ -497,7 +497,7 @@ global_vision_llm_configs:
# - id: -1004
# name: "Global Azure GPT-4o Vision"
# description: "Azure-hosted GPT-4o for vision analysis"
# provider: "AZURE_OPENAI"
# litellm_provider: "azure"
# model_name: "azure/gpt-4o-deployment"
# api_key: "your-azure-api-key-here"
# api_base: "https://your-resource.openai.azure.com"
@ -518,7 +518,7 @@ global_vision_llm_configs:
# - system_instructions: Custom prompt or empty string to use defaults
# - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty
# - citations_enabled: true = include citation instructions, false = include anti-citation instructions
# - All standard LiteLLM providers are supported
# - All standard LiteLLM provider adapter names are supported
# - rpm/tpm: Optional rate limits for load balancing (requests/tokens per minute)
# These help the router distribute load evenly and avoid rate limit errors
#

View file

@ -283,7 +283,7 @@ class VisionProvider(StrEnum):
class ConnectionProtocol(StrEnum):
OLLAMA = "OLLAMA"
OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE"
NATIVE = "NATIVE"
ANTHROPIC = "ANTHROPIC"
class ConnectionScope(StrEnum):
@ -1663,7 +1663,7 @@ class Connection(BaseModel, TimestampMixin):
__tablename__ = "connections"
protocol = Column(SQLAlchemyEnum(ConnectionProtocol), nullable=False, index=True)
native_provider = Column(String(100), nullable=True, index=True)
litellm_provider = Column(String(100), nullable=True, index=True)
base_url = Column(String(500), nullable=True)
api_key = Column(String, nullable=True)
extra = Column(JSONB, nullable=False, default=dict, server_default="{}")

View file

@ -8,6 +8,7 @@ from sqlalchemy.orm import selectinload
from app.config import config
from app.db import (
Connection,
ConnectionProtocol,
ConnectionScope,
Model,
ModelSource,
@ -40,6 +41,16 @@ router = APIRouter()
logger = logging.getLogger(__name__)
def _default_litellm_provider(protocol: ConnectionProtocol | str) -> str:
    """Pick the fallback ``litellm_provider`` string for a connection protocol.

    Args:
        protocol: A ``ConnectionProtocol`` member or its string form. Objects
            exposing ``.value`` are compared by that value; anything else is
            compared by ``str(protocol)``.

    Returns:
        ``"ollama_chat"`` for OLLAMA, ``"anthropic"`` for ANTHROPIC, and
        ``"openai"`` for OPENAI_COMPATIBLE as well as any unrecognised value.
    """
    key = protocol.value if hasattr(protocol, "value") else str(protocol)
    if key == ConnectionProtocol.OLLAMA.value:
        return "ollama_chat"
    if key == ConnectionProtocol.ANTHROPIC.value:
        return "anthropic"
    # OPENAI_COMPATIBLE and every unknown protocol share the same answer.
    return "openai"
def _model_read(model: Model | dict) -> ModelRead:
return ModelRead.model_validate(model)
@ -58,7 +69,7 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None
return ConnectionRead(
id=conn.id,
protocol=conn.protocol,
native_provider=conn.native_provider,
litellm_provider=conn.litellm_provider,
base_url=conn.base_url,
extra=conn.extra or {},
scope=conn.scope,
@ -168,8 +179,12 @@ async def create_connection(
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to create model connections in this search space",
)
payload = data.model_dump(exclude={"search_space_id"})
if not payload.get("litellm_provider"):
payload["litellm_provider"] = _default_litellm_provider(data.protocol)
conn = Connection(
**data.model_dump(exclude={"search_space_id"}),
**payload,
search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None,
user_id=user.id,
)

View file

@ -57,7 +57,7 @@ def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead:
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
)
supports_image_input = derive_supports_image_input(
provider=provider_value,
litellm_provider=provider_value.lower(),
model_name=config.model_name,
base_model=base_model,
custom_provider=config.custom_provider,
@ -147,7 +147,7 @@ async def get_global_new_llm_configs(
else None
)
supports_image_input = derive_supports_image_input(
provider=cfg.get("provider"),
litellm_provider=cfg.get("litellm_provider"),
model_name=cfg.get("model_name"),
base_model=cfg_base_model,
custom_provider=cfg.get("custom_provider"),
@ -157,7 +157,7 @@ async def get_global_new_llm_configs(
"id": cfg.get("id"),
"name": cfg.get("name"),
"description": cfg.get("description"),
"provider": cfg.get("provider"),
"provider": cfg.get("litellm_provider"),
"custom_provider": cfg.get("custom_provider"),
"model_name": cfg.get("model_name"),
"api_base": cfg.get("api_base") or None,

View file

@ -419,7 +419,7 @@ async def _get_llm_config_by_id(
"id": cfg.get("id"),
"name": cfg.get("name"),
"description": cfg.get("description"),
"provider": cfg.get("provider"),
"provider": cfg.get("litellm_provider"),
"custom_provider": cfg.get("custom_provider"),
"model_name": cfg.get("model_name"),
"api_base": cfg.get("api_base"),
@ -490,7 +490,7 @@ async def _get_image_gen_config_by_id(
"id": cfg.get("id"),
"name": cfg.get("name"),
"description": cfg.get("description"),
"provider": cfg.get("provider"),
"provider": cfg.get("litellm_provider"),
"custom_provider": cfg.get("custom_provider"),
"model_name": cfg.get("model_name"),
"api_base": cfg.get("api_base") or None,
@ -550,7 +550,7 @@ async def _get_vision_llm_config_by_id(
"id": cfg.get("id"),
"name": cfg.get("name"),
"description": cfg.get("description"),
"provider": cfg.get("provider"),
"provider": cfg.get("litellm_provider"),
"custom_provider": cfg.get("custom_provider"),
"model_name": cfg.get("model_name"),
"api_base": cfg.get("api_base") or None,

View file

@ -29,7 +29,7 @@ class ModelRead(BaseModel):
class ConnectionRead(BaseModel):
id: int
protocol: ConnectionProtocol | str
native_provider: str | None = None
litellm_provider: str | None = None
base_url: str | None = None
extra: dict[str, Any] = Field(default_factory=dict)
scope: ConnectionScope | str
@ -48,7 +48,7 @@ class ConnectionRead(BaseModel):
class ConnectionCreate(BaseModel):
protocol: ConnectionProtocol
native_provider: str | None = None
litellm_provider: str | None = Field(None, max_length=100)
base_url: str | None = Field(None, max_length=500)
api_key: str | None = None
extra: dict[str, Any] = Field(default_factory=dict)
@ -58,7 +58,7 @@ class ConnectionCreate(BaseModel):
class ConnectionUpdate(BaseModel):
native_provider: str | None = None
litellm_provider: str | None = Field(None, max_length=100)
base_url: str | None = Field(None, max_length=500)
api_key: str | None = None
extra: dict[str, Any] | None = None

View file

@ -2,7 +2,6 @@
from __future__ import annotations
import asyncio
import contextlib
import logging
from dataclasses import dataclass
@ -13,8 +12,7 @@ import httpx
import litellm
from app.db import Connection, ConnectionProtocol, Model, ModelSource
from app.services.model_resolver import NATIVE_PROVIDER_PREFIX, ensure_v1, to_litellm
from app.services.provider_api_base import resolve_api_base
from app.services.model_resolver import ensure_v1, to_litellm
logger = logging.getLogger(__name__)
@ -36,6 +34,13 @@ def _auth_headers(conn: Connection) -> dict[str, str]:
return {"Authorization": f"Bearer {conn.api_key}"}
def _anthropic_headers(conn: Connection) -> dict[str, str]:
    """Build the request headers used for Anthropic-protocol connections.

    Args:
        conn: Connection whose ``api_key`` attribute is read.

    Returns:
        A new dict that always contains ``anthropic-version`` and additionally
        ``x-api-key`` when ``conn.api_key`` is truthy.
    """
    version_header = {"anthropic-version": "2023-06-01"}
    if not conn.api_key:
        # No key configured (None or empty): send the version header alone.
        return version_header
    return {**version_header, "x-api-key": conn.api_key}
def _docker_hint(url: str | None, exc_or_status: Any) -> str:
raw = str(exc_or_status)
if not url:
@ -56,24 +61,26 @@ def _docker_hint(url: str | None, exc_or_status: Any) -> str:
async def verify_connection(conn: Connection) -> VerifyResult:
if not conn.base_url and conn.protocol in (
ConnectionProtocol.OLLAMA,
ConnectionProtocol.OPENAI_COMPATIBLE,
):
if not conn.base_url:
return VerifyResult("UNREACHABLE", False, "Base URL is required.")
if conn.protocol == ConnectionProtocol.OLLAMA:
url = f"{conn.base_url.rstrip('/')}/api/version"
elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE:
url = f"{ensure_v1(conn.base_url)}/models"
elif conn.protocol == ConnectionProtocol.ANTHROPIC:
url = f"{conn.base_url.rstrip('/')}/models"
else:
# Native providers do not share one cheap health endpoint. The model
# probe exercises the real path and is the authoritative check.
return VerifyResult("OK", True, "Native provider configuration accepted.")
return VerifyResult("UNREACHABLE", False, "Unsupported connection protocol.")
try:
async with httpx.AsyncClient(timeout=VERIFY_TIMEOUT_SECONDS) as client:
response = await client.get(url, headers=_auth_headers(conn))
headers = (
_anthropic_headers(conn)
if conn.protocol == ConnectionProtocol.ANTHROPIC
else _auth_headers(conn)
)
response = await client.get(url, headers=headers)
if response.status_code in (401, 403):
return VerifyResult("AUTH_FAILED", False, "Authentication failed.")
if response.status_code == 404:
@ -156,39 +163,25 @@ async def _discover_openai_shaped_models(conn: Connection, base_url: str | None)
]
def _litellm_valid_model_ids(provider: str, api_key: str | None) -> list[str]:
if not api_key:
async def _discover_anthropic_models(conn: Connection) -> list[dict[str, Any]]:
if not conn.base_url:
return []
try:
models = litellm.get_valid_models(
check_provider_endpoint=True,
custom_llm_provider=provider,
api_key=api_key,
)
except Exception as exc:
logger.warning("LiteLLM model discovery failed for provider %s: %s", provider, exc)
return []
provider_prefix = f"{provider}/"
return [
model.removeprefix(provider_prefix)
for model in models
if isinstance(model, str) and model.strip()
]
async def _discover_litellm_native_models(conn: Connection, provider: str) -> list[dict[str, Any]]:
model_ids = await asyncio.to_thread(_litellm_valid_model_ids, provider, conn.api_key)
url = f"{conn.base_url.rstrip('/')}/models"
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
response = await client.get(url, headers=_anthropic_headers(conn))
response.raise_for_status()
models = response.json().get("data", [])
return [
{
"model_id": model_id,
"display_name": model_id,
"model_id": item.get("id"),
"display_name": item.get("display_name") or item.get("id"),
"source": ModelSource.DISCOVERED,
"capabilities": derive_capabilities(conn, model_id),
"metadata": {},
"capabilities": derive_capabilities(conn, item.get("id"), item),
"metadata": item,
}
for model_id in model_ids
for item in models
if item.get("id")
]
@ -231,20 +224,10 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]:
]
elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE:
results = await _discover_openai_shaped_models(conn, conn.base_url)
elif conn.protocol == ConnectionProtocol.ANTHROPIC:
results = await _discover_anthropic_models(conn)
else:
provider_key = (conn.native_provider or "").upper()
provider = NATIVE_PROVIDER_PREFIX.get(provider_key, provider_key.lower())
api_base = resolve_api_base(
provider=provider_key,
provider_prefix=provider,
config_api_base=conn.base_url,
)
if api_base:
results = await _discover_openai_shaped_models(conn, api_base)
elif provider:
results = await _discover_litellm_native_models(conn, provider)
else:
results = []
results = []
if allowlist:
results = [item for item in results if item["model_id"] in allowlist]

View file

@ -19,7 +19,7 @@
# so the resolved auto-pin id is never sent to a real LLM provider.
# The values below only need to pass
# auto_model_pin_service._is_usable_global_config()
# which requires id / model_name / provider / api_key all truthy.
# which requires id / model_name / litellm_provider / api_key all truthy.
#
# Why TWO entries (premium + free):
# auto_model_pin_service.resolve_or_get_pinned_llm_config_id() splits
@ -44,9 +44,10 @@ global_llm_configs:
anonymous_enabled: false
seo_enabled: false
quality_score: 1.0
provider: "OPENAI"
litellm_provider: "openai"
model_name: "fake-e2e-model-premium"
api_key: "fake-e2e-api-key-not-for-production"
api_base: "https://api.openai.com/v1"
supports_image_input: false
quota_reserve_tokens: 1024
rpm: 1000
@ -60,9 +61,10 @@ global_llm_configs:
anonymous_enabled: false
seo_enabled: false
quality_score: 1.0
provider: "OPENAI"
litellm_provider: "openai"
model_name: "fake-e2e-model-free"
api_key: "fake-e2e-api-key-not-for-production"
api_base: "https://api.openai.com/v1"
supports_image_input: false
quota_reserve_tokens: 1024
rpm: 1000

View file

@ -25,7 +25,7 @@ _IMAGE_FIXTURE: list[dict] = [
{
"id": -1,
"name": "DALL-E 3",
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "dall-e-3",
"api_key": "sk-test",
"billing_tier": "free",
@ -33,7 +33,7 @@ _IMAGE_FIXTURE: list[dict] = [
{
"id": -2,
"name": "GPT-Image 1 (premium)",
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-image-1",
"api_key": "sk-test",
"billing_tier": "premium",
@ -41,7 +41,7 @@ _IMAGE_FIXTURE: list[dict] = [
{
"id": -20_001,
"name": "google/gemini-2.5-flash-image (OpenRouter)",
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "google/gemini-2.5-flash-image",
"api_key": "sk-or-test",
"api_base": "https://openrouter.ai/api/v1",
@ -54,7 +54,7 @@ _VISION_FIXTURE: list[dict] = [
{
"id": -1,
"name": "GPT-4o Vision",
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-4o",
"api_key": "sk-test",
"billing_tier": "free",
@ -62,7 +62,7 @@ _VISION_FIXTURE: list[dict] = [
{
"id": -2,
"name": "Claude 3.5 Sonnet (premium)",
"provider": "ANTHROPIC",
"litellm_provider": "anthropic",
"model_name": "claude-3-5-sonnet",
"api_key": "sk-ant-test",
"billing_tier": "premium",
@ -70,7 +70,7 @@ _VISION_FIXTURE: list[dict] = [
{
"id": -30_001,
"name": "openai/gpt-4o (OpenRouter)",
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "openai/gpt-4o",
"api_key": "sk-or-test",
"api_base": "https://openrouter.ai/api/v1",

View file

@ -26,7 +26,7 @@ _FIXTURE: list[dict] = [
"id": -1,
"name": "GPT-4o (explicit true)",
"description": "vision-capable, explicit YAML override",
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-4o",
"api_key": "sk-test",
"billing_tier": "free",
@ -36,7 +36,7 @@ _FIXTURE: list[dict] = [
"id": -2,
"name": "DeepSeek V3 (explicit false)",
"description": "OpenRouter dynamic — modality-derived false",
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "deepseek/deepseek-v3.2-exp",
"api_key": "sk-or-test",
"api_base": "https://openrouter.ai/api/v1",
@ -47,7 +47,7 @@ _FIXTURE: list[dict] = [
"id": -10_010,
"name": "Unannotated GPT-4o",
"description": "no flag set — resolver should derive True via LiteLLM",
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-4o",
"api_key": "sk-test",
"billing_tier": "free",
@ -57,7 +57,7 @@ _FIXTURE: list[dict] = [
"id": -10_011,
"name": "Unannotated unknown model",
"description": "unmapped — default-allow True",
"provider": "CUSTOM",
"litellm_provider": "custom",
"custom_provider": "brand_new_proxy",
"model_name": "brand-new-model-x9",
"api_key": "sk-test",

View file

@ -2,11 +2,12 @@ from app.services.global_model_catalog import materialize_global_model_catalog
from app.services.model_resolver import ensure_v1, to_litellm
def test_openai_compatible_resolver_normalizes_v1() -> None:
def test_openai_compatible_resolver_uses_explicit_api_base() -> None:
model, kwargs = to_litellm(
{
"protocol": "OPENAI_COMPATIBLE",
"base_url": "http://host.docker.internal:1234",
"litellm_provider": "openai",
"base_url": "http://host.docker.internal:1234/v1",
"api_key": "local-key",
"extra": {},
},
@ -23,6 +24,7 @@ def test_ollama_resolver_uses_native_api_base() -> None:
model, kwargs = to_litellm(
{
"protocol": "OLLAMA",
"litellm_provider": "ollama_chat",
"base_url": "http://host.docker.internal:11434",
"api_key": None,
"extra": {},
@ -40,9 +42,10 @@ def test_global_materialization_preserves_tier_and_keeps_key_server_side() -> No
{
"id": -101,
"name": "OpenRouter Free",
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "meta-llama/llama-3.1-8b-instruct:free",
"api_key": "sk-global-secret",
"api_base": "https://openrouter.ai/api/v1",
"billing_tier": "free",
"anonymous_enabled": True,
"seo_enabled": True,
@ -52,9 +55,10 @@ def test_global_materialization_preserves_tier_and_keeps_key_server_side() -> No
{
"id": -102,
"name": "OpenRouter Premium",
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "anthropic/claude-sonnet-4",
"api_key": "sk-global-secret",
"api_base": "https://openrouter.ai/api/v1",
"billing_tier": "premium",
},
],