From c6a25cc1fe1e128852146a09d2877d3496a2aaee Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 18:20:53 +0530 Subject: [PATCH] refactor(model-connections): streamline global model config persistence --- .../versions/156_add_model_connections.py | 266 +++++++++++------- surfsense_backend/app/config/__init__.py | 9 +- .../app/config/global_llm_config.example.yaml | 60 ++-- surfsense_backend/app/db.py | 4 +- .../app/routes/model_connections_routes.py | 19 +- .../app/routes/new_llm_config_routes.py | 6 +- .../app/routes/search_spaces_routes.py | 6 +- .../app/schemas/model_connections.py | 6 +- .../app/services/model_connection_service.py | 85 +++--- .../tests/e2e/fixtures/global_llm_config.yaml | 8 +- .../routes/test_global_configs_is_premium.py | 12 +- ...t_global_new_llm_configs_supports_image.py | 8 +- .../unit/services/test_model_connections.py | 12 +- 13 files changed, 277 insertions(+), 224 deletions(-) diff --git a/surfsense_backend/alembic/versions/156_add_model_connections.py b/surfsense_backend/alembic/versions/156_add_model_connections.py index 0a11d7f9d..185debca4 100644 --- a/surfsense_backend/alembic/versions/156_add_model_connections.py +++ b/surfsense_backend/alembic/versions/156_add_model_connections.py @@ -20,7 +20,7 @@ depends_on: str | Sequence[str] | None = None connection_protocol = postgresql.ENUM( "OLLAMA", "OPENAI_COMPATIBLE", - "NATIVE", + "ANTHROPIC", name="connectionprotocol", create_type=False, ) @@ -39,122 +39,172 @@ model_source = postgresql.ENUM( ) +def _table_exists(table_name: str) -> bool: + return table_name in sa.inspect(op.get_bind()).get_table_names() + + +def _column_exists(table_name: str, column_name: str) -> bool: + if not _table_exists(table_name): + return False + return column_name in { + column["name"] for column in sa.inspect(op.get_bind()).get_columns(table_name) + } + + +def _index_exists(table_name: str, index_name: str) -> bool: + if not _table_exists(table_name): + return False + return index_name in { + index["name"] for index in sa.inspect(op.get_bind()).get_indexes(table_name) + } + + +def _create_index_if_missing( + index_name: str, + table_name: str, + columns: list[str], +) -> None: + if not _index_exists(table_name, index_name): + op.create_index(index_name, table_name, columns, unique=False) + + +def _add_searchspace_column_if_missing(column_name: str) -> None: + if not _column_exists("searchspaces", column_name): + op.add_column("searchspaces", sa.Column(column_name, sa.Integer(), nullable=True)) + + def upgrade() -> None: bind = op.get_bind() connection_protocol.create(bind, checkfirst=True) + op.execute("ALTER TYPE connectionprotocol ADD VALUE IF NOT EXISTS 'ANTHROPIC'") connection_scope.create(bind, checkfirst=True) model_source.create(bind, checkfirst=True) - op.create_table( + if _table_exists("connections"): + if _column_exists("connections", "native_provider") and not _column_exists( + "connections", "litellm_provider" + ): + op.alter_column( + "connections", + "native_provider", + new_column_name="litellm_provider", + existing_type=sa.String(length=100), + existing_nullable=True, + ) + elif not _column_exists("connections", "litellm_provider"): + op.add_column( + "connections", + sa.Column("litellm_provider", sa.String(length=100), nullable=True), + ) + else: + op.create_table( + "connections", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("protocol", connection_protocol, nullable=False), + sa.Column("litellm_provider", sa.String(length=100), nullable=True), + sa.Column("base_url", sa.String(length=500), nullable=True), + sa.Column("api_key", sa.String(), nullable=True), + sa.Column( + "extra", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column("scope", connection_scope, nullable=False), + sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False), + sa.Column("search_space_id", sa.Integer(), nullable=True), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("last_verified_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column("last_status", sa.String(length=50), nullable=True), + sa.Column("last_error", sa.Text(), nullable=True), + sa.CheckConstraint( + "(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR " + "(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR " + "(scope = 'USER' AND user_id IS NOT NULL)", + name="ck_connections_scope_owner", + ), + sa.ForeignKeyConstraint( + ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + if _index_exists("connections", "ix_connections_native_provider") and not _index_exists( + "connections", "ix_connections_litellm_provider" + ): + op.execute( + "ALTER INDEX ix_connections_native_provider " + "RENAME TO ix_connections_litellm_provider" + ) + _create_index_if_missing("ix_connections_protocol", "connections", ["protocol"]) + _create_index_if_missing( + "ix_connections_litellm_provider", "connections", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("protocol", connection_protocol, nullable=False), - sa.Column("native_provider", sa.String(length=100), nullable=True), - sa.Column("base_url", sa.String(length=500), nullable=True), - sa.Column("api_key", sa.String(), nullable=True), - sa.Column( - "extra", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column("scope", connection_scope, nullable=False), - sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False), - sa.Column("search_space_id", sa.Integer(), nullable=True), - sa.Column("user_id", sa.UUID(), nullable=True), - sa.Column("last_verified_at", sa.TIMESTAMP(timezone=True), nullable=True), - sa.Column("last_status", sa.String(length=50), nullable=True), - sa.Column("last_error", sa.Text(), nullable=True), - sa.CheckConstraint( - "(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR " - "(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR " - "(scope = 'USER' AND user_id IS NOT NULL)", - name="ck_connections_scope_owner", - ), - sa.ForeignKeyConstraint( - ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" - ), - sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("id"), + ["litellm_provider"], ) - op.create_index(op.f("ix_connections_protocol"), "connections", ["protocol"], unique=False) - op.create_index( - op.f("ix_connections_native_provider"), - "connections", - ["native_provider"], - unique=False, - ) - op.create_index(op.f("ix_connections_scope"), "connections", ["scope"], unique=False) + _create_index_if_missing("ix_connections_scope", "connections", ["scope"]) - op.create_table( - "models", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("connection_id", sa.Integer(), nullable=False), - sa.Column("model_id", sa.String(length=255), nullable=False), - sa.Column("display_name", sa.String(length=255), nullable=True), - sa.Column( - "source", - model_source, - server_default="DISCOVERED", - nullable=False, - ), - sa.Column( - "capabilities", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column( - "capabilities_declared", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column( - "capabilities_verified", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column( - "capabilities_override", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column("embedding_dimension", sa.Integer(), nullable=True), - sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False), - sa.Column("billing_tier", sa.String(length=50), nullable=True), - sa.Column( - "catalog", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.ForeignKeyConstraint(["connection_id"], ["connections.id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint( - "connection_id", "model_id", name="uq_models_connection_model_id" - ), - ) - op.create_index(op.f("ix_models_connection_id"), "models", ["connection_id"], unique=False) - op.create_index("ix_models_model_id", "models", ["model_id"], unique=False) - op.create_index(op.f("ix_models_billing_tier"), "models", ["billing_tier"], unique=False) + if not _table_exists("models"): + op.create_table( + "models", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("connection_id", sa.Integer(), nullable=False), + sa.Column("model_id", sa.String(length=255), nullable=False), + sa.Column("display_name", sa.String(length=255), nullable=True), + sa.Column( + "source", + model_source, + server_default="DISCOVERED", + nullable=False, + ), + sa.Column( + "capabilities", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column( + "capabilities_declared", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column( + "capabilities_verified", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column( + "capabilities_override", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column("embedding_dimension", sa.Integer(), nullable=True), + sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False), + sa.Column("billing_tier", sa.String(length=50), nullable=True), + sa.Column( + "catalog", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.ForeignKeyConstraint(["connection_id"], ["connections.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint( + "connection_id", "model_id", name="uq_models_connection_model_id" + ), + ) + _create_index_if_missing("ix_models_connection_id", "models", ["connection_id"]) + _create_index_if_missing("ix_models_model_id", "models", ["model_id"]) + _create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"]) - op.add_column( - "searchspaces", - sa.Column("chat_model_id", sa.Integer(), nullable=True), - ) - op.add_column( - "searchspaces", - sa.Column("image_gen_model_id", sa.Integer(), nullable=True), - ) - op.add_column( - "searchspaces", - sa.Column("vision_model_id", sa.Integer(), nullable=True), - ) + _add_searchspace_column_if_missing("chat_model_id") + _add_searchspace_column_if_missing("image_gen_model_id") + _add_searchspace_column_if_missing("vision_model_id") def downgrade() -> None: @@ -168,7 +218,7 @@ def downgrade() -> None: op.drop_table("models") op.drop_index(op.f("ix_connections_scope"), table_name="connections") - op.drop_index(op.f("ix_connections_native_provider"), table_name="connections") + op.drop_index(op.f("ix_connections_litellm_provider"), table_name="connections") op.drop_index(op.f("ix_connections_protocol"), table_name="connections") op.drop_table("connections") diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index b8addb45d..fd8b29116 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -78,8 +78,7 @@ def load_global_llm_configs(): # stamps) never leak into the cached YAML structure. configs = copy.deepcopy(data.get("global_llm_configs", [])) - # Lazy import keeps the `app.config` -> `app.services` edge one-way - # and matches the `provider_api_base` pattern used elsewhere. + # Lazy import keeps the `app.config` -> `app.services` edge one-way. from app.services.provider_capabilities import derive_supports_image_input seen_slugs: dict[str, int] = {} @@ -104,7 +103,7 @@ def load_global_llm_configs(): else None ) cfg["supports_image_input"] = derive_supports_image_input( - provider=cfg.get("provider"), + litellm_provider=cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), @@ -123,7 +122,7 @@ def load_global_llm_configs(): # Stamp Auto (Fastest) ranking metadata. YAML configs are always # Tier A — operator-curated, locked first when premium-eligible. # The OpenRouter refresh tick later re-stamps health for any cfg - # whose provider == "OPENROUTER" via _enrich_health. + # whose litellm_provider == "openrouter" via _enrich_health. try: from app.services.quality_score import static_score_yaml @@ -133,7 +132,7 @@ def load_global_llm_configs(): cfg["quality_score_static"] = static_q cfg["quality_score"] = static_q cfg["quality_score_health"] = None - # YAML cfgs whose provider is OPENROUTER are also subject + # YAML cfgs whose litellm_provider is openrouter are also subject # to health gating against their own /endpoints data — a # hand-picked dead OR model is still dead. _enrich_health # re-stamps health_gated for them on the next refresh tick. diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index b0eee6458..06676511f 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -18,7 +18,7 @@ # - Configure router_settings below to customize the load balancing behavior # # Static config shape: -# - Connection fields: provider, api_key, api_base, api_version +# - Connection fields: litellm_provider, api_key, api_base, api_version # - Model fields: model_name, billing_tier, rpm/tpm, litellm_params # - Prompt defaults: system_instructions, citations_enabled # IDs share one GLOBAL model namespace across chat, vision, and image generation. @@ -75,10 +75,10 @@ global_llm_configs: seo_enabled: true seo_slug: "gpt-4-turbo" quota_reserve_tokens: 4000 - provider: "OPENAI" + litellm_provider: "openai" model_name: "gpt-4-turbo-preview" api_key: "sk-your-openai-api-key-here" - api_base: "" + api_base: "https://api.openai.com/v1" # Rate limits for load balancing (requests/tokens per minute) rpm: 500 # Requests per minute tpm: 100000 # Tokens per minute @@ -99,10 +99,10 @@ global_llm_configs: seo_enabled: true seo_slug: "claude-3-opus" quota_reserve_tokens: 4000 - provider: "ANTHROPIC" + litellm_provider: "anthropic" model_name: "claude-3-opus-20240229" api_key: "sk-ant-your-anthropic-api-key-here" - api_base: "" + api_base: "https://api.anthropic.com/v1" rpm: 1000 tpm: 100000 litellm_params: @@ -121,10 +121,10 @@ global_llm_configs: seo_enabled: true seo_slug: "gpt-3.5-turbo-fast" quota_reserve_tokens: 2000 - provider: "OPENAI" + litellm_provider: "openai" model_name: "gpt-3.5-turbo" api_key: "sk-your-openai-api-key-here" - api_base: "" + api_base: "https://api.openai.com/v1" rpm: 3500 # GPT-3.5 has higher rate limits tpm: 200000 litellm_params: @@ -143,7 +143,7 @@ global_llm_configs: seo_enabled: true seo_slug: "deepseek-chat-chinese" quota_reserve_tokens: 4000 - provider: "DEEPSEEK" + litellm_provider: "openai" model_name: "deepseek-chat" api_key: "your-deepseek-api-key-here" api_base: "https://api.deepseek.com/v1" @@ -175,7 +175,7 @@ global_llm_configs: seo_enabled: true seo_slug: "azure-gpt-4o" quota_reserve_tokens: 4000 - provider: "AZURE" + litellm_provider: "azure" # model_name format for Azure: azure/ model_name: "azure/gpt-4o-deployment" api_key: "your-azure-api-key-here" @@ -203,7 +203,7 @@ global_llm_configs: seo_enabled: true seo_slug: "azure-gpt-4-turbo" quota_reserve_tokens: 4000 - provider: "AZURE" + litellm_provider: "azure" model_name: "azure/gpt-4-turbo-deployment" api_key: "your-azure-api-key-here" api_base: "https://your-resource.openai.azure.com" @@ -227,10 +227,10 @@ global_llm_configs: seo_enabled: true seo_slug: "groq-llama-3" quota_reserve_tokens: 8000 - provider: "GROQ" + litellm_provider: "groq" model_name: "llama3-70b-8192" api_key: "your-groq-api-key-here" - api_base: "" + api_base: "https://api.groq.com/openai/v1" rpm: 30 # Groq has lower rate limits on free tier tpm: 14400 litellm_params: @@ -249,7 +249,7 @@ global_llm_configs: seo_enabled: true seo_slug: "minimax-m3" quota_reserve_tokens: 4000 - provider: "MINIMAX" + litellm_provider: "openai" model_name: "MiniMax-M3" api_key: "your-minimax-api-key-here" api_base: "https://api.minimax.io/v1" @@ -288,10 +288,10 @@ global_llm_configs: anonymous_enabled: false seo_enabled: false quota_reserve_tokens: 1000 - provider: "OPENAI" + litellm_provider: "openai" model_name: "gpt-4o-mini" api_key: "sk-your-openai-api-key-here" - api_base: "" + api_base: "https://api.openai.com/v1" rpm: 3500 tpm: 200000 litellm_params: @@ -391,10 +391,10 @@ global_image_generation_configs: - id: -2001 name: "Global DALL-E 3" description: "OpenAI's DALL-E 3 for high-quality image generation" - provider: "OPENAI" + litellm_provider: "openai" model_name: "dall-e-3" api_key: "sk-your-openai-api-key-here" - api_base: "" + api_base: "https://api.openai.com/v1" rpm: 50 # Requests per minute (image gen is rate-limited by RPM, not tokens) litellm_params: {} @@ -402,10 +402,10 @@ global_image_generation_configs: - id: -2002 name: "Global GPT Image 1" description: "OpenAI's GPT Image 1 model" - provider: "OPENAI" + litellm_provider: "openai" model_name: "gpt-image-1" api_key: "sk-your-openai-api-key-here" - api_base: "" + api_base: "https://api.openai.com/v1" rpm: 50 litellm_params: {} @@ -413,7 +413,7 @@ global_image_generation_configs: - id: -2003 name: "Global Azure DALL-E 3" description: "Azure-hosted DALL-E 3 deployment" - provider: "AZURE_OPENAI" + litellm_provider: "azure" model_name: "azure/dall-e-3-deployment" api_key: "your-azure-api-key-here" api_base: "https://your-resource.openai.azure.com" @@ -426,10 +426,10 @@ global_image_generation_configs: # - id: -2004 # name: "Global Gemini Image Gen" # description: "Google Gemini image generation via OpenRouter" - # provider: "OPENROUTER" + # litellm_provider: "openrouter" # model_name: "google/gemini-2.5-flash-image" # api_key: "your-openrouter-api-key-here" - # api_base: "" + # api_base: "https://openrouter.ai/api/v1" # rpm: 30 # litellm_params: {} @@ -455,10 +455,10 @@ global_vision_llm_configs: - id: -1001 name: "Global GPT-4o Vision" description: "OpenAI's GPT-4o with strong vision capabilities" - provider: "OPENAI" + litellm_provider: "openai" model_name: "gpt-4o" api_key: "sk-your-openai-api-key-here" - api_base: "" + api_base: "https://api.openai.com/v1" rpm: 500 tpm: 100000 litellm_params: @@ -469,10 +469,10 @@ global_vision_llm_configs: - id: -1002 name: "Global Gemini 2.0 Flash" description: "Google's fast vision model with large context" - provider: "GOOGLE" + litellm_provider: "gemini" model_name: "gemini-2.0-flash" api_key: "your-google-ai-api-key-here" - api_base: "" + api_base: "https://generativelanguage.googleapis.com/v1beta" rpm: 1000 tpm: 200000 litellm_params: @@ -483,10 +483,10 @@ global_vision_llm_configs: - id: -1003 name: "Global Claude 3.5 Sonnet Vision" description: "Anthropic's Claude 3.5 Sonnet with vision support" - provider: "ANTHROPIC" + litellm_provider: "anthropic" model_name: "claude-3-5-sonnet-20241022" api_key: "sk-ant-your-anthropic-api-key-here" - api_base: "" + api_base: "https://api.anthropic.com/v1" rpm: 1000 tpm: 100000 litellm_params: @@ -497,7 +497,7 @@ global_vision_llm_configs: # - id: -1004 # name: "Global Azure GPT-4o Vision" # description: "Azure-hosted GPT-4o for vision analysis" - # provider: "AZURE_OPENAI" + # litellm_provider: "azure" # model_name: "azure/gpt-4o-deployment" # api_key: "your-azure-api-key-here" # api_base: "https://your-resource.openai.azure.com" @@ -518,7 +518,7 @@ global_vision_llm_configs: # - system_instructions: Custom prompt or empty string to use defaults # - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty # - citations_enabled: true = include citation instructions, false = include anti-citation instructions -# - All standard LiteLLM providers are supported +# - All standard LiteLLM provider adapter names are supported # - rpm/tpm: Optional rate limits for load balancing (requests/tokens per minute) # These help the router distribute load evenly and avoid rate limit errors # diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 9756cb32f..4c628b05a 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -283,7 +283,7 @@ class VisionProvider(StrEnum): class ConnectionProtocol(StrEnum): OLLAMA = "OLLAMA" OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" - NATIVE = "NATIVE" + ANTHROPIC = "ANTHROPIC" class ConnectionScope(StrEnum): @@ -1663,7 +1663,7 @@ class Connection(BaseModel, TimestampMixin): __tablename__ = "connections" protocol = Column(SQLAlchemyEnum(ConnectionProtocol), nullable=False, index=True) - native_provider = Column(String(100), nullable=True, index=True) + litellm_provider = Column(String(100), nullable=True, index=True) base_url = Column(String(500), nullable=True) api_key = Column(String, nullable=True) extra = Column(JSONB, nullable=False, default=dict, server_default="{}") diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 5872671b1..cae951c3a 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -8,6 +8,7 @@ from sqlalchemy.orm import selectinload from app.config import config from app.db import ( Connection, + ConnectionProtocol, ConnectionScope, Model, ModelSource, @@ -40,6 +41,16 @@ router = APIRouter() logger = logging.getLogger(__name__) +def _default_litellm_provider(protocol: ConnectionProtocol | str) -> str: + protocol_value = getattr(protocol, "value", str(protocol)) + defaults = { + ConnectionProtocol.OLLAMA.value: "ollama_chat", + ConnectionProtocol.ANTHROPIC.value: "anthropic", + ConnectionProtocol.OPENAI_COMPATIBLE.value: "openai", + } + return defaults.get(protocol_value, "openai") + + def _model_read(model: Model | dict) -> ModelRead: return ModelRead.model_validate(model) @@ -58,7 +69,7 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None return ConnectionRead( id=conn.id, protocol=conn.protocol, - native_provider=conn.native_provider, + litellm_provider=conn.litellm_provider, base_url=conn.base_url, extra=conn.extra or {}, scope=conn.scope, @@ -168,8 +179,12 @@ async def create_connection( Permission.LLM_CONFIGS_CREATE.value, "You don't have permission to create model connections in this search space", ) + payload = data.model_dump(exclude={"search_space_id"}) + if not payload.get("litellm_provider"): + payload["litellm_provider"] = _default_litellm_provider(data.protocol) + conn = Connection( - **data.model_dump(exclude={"search_space_id"}), + **payload, search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None, user_id=user.id, ) diff --git a/surfsense_backend/app/routes/new_llm_config_routes.py b/surfsense_backend/app/routes/new_llm_config_routes.py index 84d66bb13..531fa8730 100644 --- a/surfsense_backend/app/routes/new_llm_config_routes.py +++ b/surfsense_backend/app/routes/new_llm_config_routes.py @@ -57,7 +57,7 @@ def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead: litellm_params.get("base_model") if isinstance(litellm_params, dict) else None ) supports_image_input = derive_supports_image_input( - provider=provider_value, + litellm_provider=provider_value.lower(), model_name=config.model_name, base_model=base_model, custom_provider=config.custom_provider, @@ -147,7 +147,7 @@ async def get_global_new_llm_configs( else None ) supports_image_input = derive_supports_image_input( - provider=cfg.get("provider"), + litellm_provider=cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=cfg_base_model, custom_provider=cfg.get("custom_provider"), @@ -157,7 +157,7 @@ async def get_global_new_llm_configs( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("provider"), + "provider": cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 898077b7a..2cda04221 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -419,7 +419,7 @@ async def _get_llm_config_by_id( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("provider"), + "provider": cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base"), @@ -490,7 +490,7 @@ async def _get_image_gen_config_by_id( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("provider"), + "provider": cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, @@ -550,7 +550,7 @@ async def _get_vision_llm_config_by_id( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("provider"), + "provider": cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py index ea1ec4e88..306dd63c8 100644 --- a/surfsense_backend/app/schemas/model_connections.py +++ b/surfsense_backend/app/schemas/model_connections.py @@ -29,7 +29,7 @@ class ModelRead(BaseModel): class ConnectionRead(BaseModel): id: int protocol: ConnectionProtocol | str - native_provider: str | None = None + litellm_provider: str | None = None base_url: str | None = None extra: dict[str, Any] = Field(default_factory=dict) scope: ConnectionScope | str @@ -48,7 +48,7 @@ class ConnectionRead(BaseModel): class ConnectionCreate(BaseModel): protocol: ConnectionProtocol - native_provider: str | None = None + litellm_provider: str | None = Field(None, max_length=100) base_url: str | None = Field(None, max_length=500) api_key: str | None = None extra: dict[str, Any] = Field(default_factory=dict) @@ -58,7 +58,7 @@ class ConnectionCreate(BaseModel): class ConnectionUpdate(BaseModel): - native_provider: str | None = None + litellm_provider: str | None = Field(None, max_length=100) base_url: str | None = Field(None, max_length=500) api_key: str | None = None extra: dict[str, Any] | None = None diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index 42a4792a4..5e5b231f9 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import contextlib import logging from dataclasses import dataclass @@ -13,8 +12,7 @@ import httpx import litellm from app.db import Connection, ConnectionProtocol, Model, ModelSource -from app.services.model_resolver import NATIVE_PROVIDER_PREFIX, ensure_v1, to_litellm -from app.services.provider_api_base import resolve_api_base +from app.services.model_resolver import ensure_v1, to_litellm logger = logging.getLogger(__name__) @@ -36,6 +34,13 @@ def _auth_headers(conn: Connection) -> dict[str, str]: return {"Authorization": f"Bearer {conn.api_key}"} +def _anthropic_headers(conn: Connection) -> dict[str, str]: + headers = {"anthropic-version": "2023-06-01"} + if conn.api_key: + headers["x-api-key"] = conn.api_key + return headers + + def _docker_hint(url: str | None, exc_or_status: Any) -> str: raw = str(exc_or_status) if not url: @@ -56,24 +61,26 @@ def _docker_hint(url: str | None, exc_or_status: Any) -> str: async def verify_connection(conn: Connection) -> VerifyResult: - if not conn.base_url and conn.protocol in ( - ConnectionProtocol.OLLAMA, - ConnectionProtocol.OPENAI_COMPATIBLE, - ): + if not conn.base_url: return VerifyResult("UNREACHABLE", False, "Base URL is required.") if conn.protocol == ConnectionProtocol.OLLAMA: url = f"{conn.base_url.rstrip('/')}/api/version" elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: url = f"{ensure_v1(conn.base_url)}/models" + elif conn.protocol == ConnectionProtocol.ANTHROPIC: + url = f"{conn.base_url.rstrip('/')}/models" else: - # Native providers do not share one cheap health endpoint. The model - # probe exercises the real path and is the authoritative check. - return VerifyResult("OK", True, "Native provider configuration accepted.") + return VerifyResult("UNREACHABLE", False, "Unsupported connection protocol.") try: async with httpx.AsyncClient(timeout=VERIFY_TIMEOUT_SECONDS) as client: - response = await client.get(url, headers=_auth_headers(conn)) + headers = ( + _anthropic_headers(conn) + if conn.protocol == ConnectionProtocol.ANTHROPIC + else _auth_headers(conn) + ) + response = await client.get(url, headers=headers) if response.status_code in (401, 403): return VerifyResult("AUTH_FAILED", False, "Authentication failed.") if response.status_code == 404: @@ -156,39 +163,25 @@ async def _discover_openai_shaped_models(conn: Connection, base_url: str | None) ] -def _litellm_valid_model_ids(provider: str, api_key: str | None) -> list[str]: - if not api_key: +async def _discover_anthropic_models(conn: Connection) -> list[dict[str, Any]]: + if not conn.base_url: return [] - try: - models = litellm.get_valid_models( - check_provider_endpoint=True, - custom_llm_provider=provider, - api_key=api_key, - ) - except Exception as exc: - logger.warning("LiteLLM model discovery failed for provider %s: %s", provider, exc) - return [] - - provider_prefix = f"{provider}/" - return [ - model.removeprefix(provider_prefix) - for model in models - if isinstance(model, str) and model.strip() - ] - - -async def _discover_litellm_native_models(conn: Connection, provider: str) -> list[dict[str, Any]]: - model_ids = await asyncio.to_thread(_litellm_valid_model_ids, provider, conn.api_key) + url = f"{conn.base_url.rstrip('/')}/models" + async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: + response = await client.get(url, headers=_anthropic_headers(conn)) + response.raise_for_status() + models = response.json().get("data", []) return [ { - "model_id": model_id, - "display_name": model_id, + "model_id": item.get("id"), + "display_name": item.get("display_name") or item.get("id"), "source": ModelSource.DISCOVERED, - "capabilities": derive_capabilities(conn, model_id), - "metadata": {}, + "capabilities": derive_capabilities(conn, item.get("id"), item), + "metadata": item, } - for model_id in model_ids + for item in models + if item.get("id") ] @@ -231,20 +224,10 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]: ] elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: results = await _discover_openai_shaped_models(conn, conn.base_url) + elif conn.protocol == ConnectionProtocol.ANTHROPIC: + results = await _discover_anthropic_models(conn) else: - provider_key = (conn.native_provider or "").upper() - provider = NATIVE_PROVIDER_PREFIX.get(provider_key, provider_key.lower()) - api_base = resolve_api_base( - provider=provider_key, - provider_prefix=provider, - config_api_base=conn.base_url, - ) - if api_base: - results = await _discover_openai_shaped_models(conn, api_base) - elif provider: - results = await _discover_litellm_native_models(conn, provider) - else: - results = [] + results = [] if allowlist: results = [item for item in results if item["model_id"] in allowlist] diff --git a/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml b/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml index 017fa1eb3..9ea5e1a29 100644 --- a/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml +++ b/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml @@ -19,7 +19,7 @@ # so the resolved auto-pin id is never sent to a real LLM provider. # The values below only need to pass # auto_model_pin_service._is_usable_global_config() -# which requires id / model_name / provider / api_key all truthy. +# which requires id / model_name / litellm_provider / api_key all truthy. # # Why TWO entries (premium + free): # auto_model_pin_service.resolve_or_get_pinned_llm_config_id() splits @@ -44,9 +44,10 @@ global_llm_configs: anonymous_enabled: false seo_enabled: false quality_score: 1.0 - provider: "OPENAI" + litellm_provider: "openai" model_name: "fake-e2e-model-premium" api_key: "fake-e2e-api-key-not-for-production" + api_base: "https://api.openai.com/v1" supports_image_input: false quota_reserve_tokens: 1024 rpm: 1000 @@ -60,9 +61,10 @@ global_llm_configs: anonymous_enabled: false seo_enabled: false quality_score: 1.0 - provider: "OPENAI" + litellm_provider: "openai" model_name: "fake-e2e-model-free" api_key: "fake-e2e-api-key-not-for-production" + api_base: "https://api.openai.com/v1" supports_image_input: false quota_reserve_tokens: 1024 rpm: 1000 diff --git a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py index 2b6c76485..fff61f14b 100644 --- a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py +++ b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py @@ -25,7 +25,7 @@ _IMAGE_FIXTURE: list[dict] = [ { "id": -1, "name": "DALL-E 3", - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "dall-e-3", "api_key": "sk-test", "billing_tier": "free", @@ -33,7 +33,7 @@ _IMAGE_FIXTURE: list[dict] = [ { "id": -2, "name": "GPT-Image 1 (premium)", - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-image-1", "api_key": "sk-test", "billing_tier": "premium", @@ -41,7 +41,7 @@ _IMAGE_FIXTURE: list[dict] = [ { "id": -20_001, "name": "google/gemini-2.5-flash-image (OpenRouter)", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemini-2.5-flash-image", "api_key": "sk-or-test", "api_base": "https://openrouter.ai/api/v1", @@ -54,7 +54,7 @@ _VISION_FIXTURE: list[dict] = [ { "id": -1, "name": "GPT-4o Vision", - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-4o", "api_key": "sk-test", "billing_tier": "free", @@ -62,7 +62,7 @@ _VISION_FIXTURE: list[dict] = [ { "id": -2, "name": "Claude 3.5 Sonnet (premium)", - "provider": "ANTHROPIC", + "litellm_provider": "anthropic", "model_name": "claude-3-5-sonnet", "api_key": "sk-ant-test", "billing_tier": "premium", @@ -70,7 +70,7 @@ _VISION_FIXTURE: list[dict] = [ { "id": -30_001, "name": "openai/gpt-4o (OpenRouter)", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "openai/gpt-4o", "api_key": "sk-or-test", "api_base": "https://openrouter.ai/api/v1", diff --git a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py index b47d9134b..67d1112f3 100644 --- a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py +++ b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py @@ -26,7 +26,7 @@ _FIXTURE: list[dict] = [ "id": -1, "name": "GPT-4o (explicit true)", "description": "vision-capable, explicit YAML override", - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-4o", "api_key": "sk-test", "billing_tier": "free", @@ -36,7 +36,7 @@ _FIXTURE: list[dict] = [ "id": -2, "name": "DeepSeek V3 (explicit false)", "description": "OpenRouter dynamic — modality-derived false", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "deepseek/deepseek-v3.2-exp", "api_key": "sk-or-test", "api_base": "https://openrouter.ai/api/v1", @@ -47,7 +47,7 @@ _FIXTURE: list[dict] = [ "id": -10_010, "name": "Unannotated GPT-4o", "description": "no flag set — resolver should derive True via LiteLLM", - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-4o", "api_key": "sk-test", "billing_tier": "free", @@ -57,7 +57,7 @@ _FIXTURE: list[dict] = [ "id": -10_011, "name": "Unannotated unknown model", "description": "unmapped — default-allow True", - "provider": "CUSTOM", + "litellm_provider": "custom", "custom_provider": "brand_new_proxy", "model_name": "brand-new-model-x9", "api_key": "sk-test", diff --git a/surfsense_backend/tests/unit/services/test_model_connections.py b/surfsense_backend/tests/unit/services/test_model_connections.py index 98042501b..797f794b1 100644 --- a/surfsense_backend/tests/unit/services/test_model_connections.py +++ b/surfsense_backend/tests/unit/services/test_model_connections.py @@ -2,11 +2,12 @@ from app.services.global_model_catalog import materialize_global_model_catalog from app.services.model_resolver import ensure_v1, to_litellm -def test_openai_compatible_resolver_normalizes_v1() -> None: +def test_openai_compatible_resolver_uses_explicit_api_base() -> None: model, kwargs = to_litellm( { "protocol": "OPENAI_COMPATIBLE", - "base_url": "http://host.docker.internal:1234", + "litellm_provider": "openai", + "base_url": "http://host.docker.internal:1234/v1", "api_key": "local-key", "extra": {}, }, @@ -23,6 +24,7 @@ def test_ollama_resolver_uses_native_api_base() -> None: model, kwargs = to_litellm( { "protocol": "OLLAMA", + "litellm_provider": "ollama_chat", "base_url": "http://host.docker.internal:11434", "api_key": None, "extra": {}, @@ -40,9 +42,10 @@ def test_global_materialization_preserves_tier_and_keeps_key_server_side() -> No { "id": -101, "name": "OpenRouter Free", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "meta-llama/llama-3.1-8b-instruct:free", "api_key": "sk-global-secret", + "api_base": "https://openrouter.ai/api/v1", "billing_tier": "free", "anonymous_enabled": True, "seo_enabled": True, @@ -52,9 +55,10 @@ def test_global_materialization_preserves_tier_and_keeps_key_server_side() -> No { "id": -102, "name": "OpenRouter Premium", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "anthropic/claude-sonnet-4", "api_key": "sk-global-secret", + "api_base": "https://openrouter.ai/api/v1", "billing_tier": "premium", }, ],