diff --git a/surfsense_backend/alembic/versions/156_add_model_connections.py b/surfsense_backend/alembic/versions/156_add_model_connections.py index 185debca4..64614db99 100644 --- a/surfsense_backend/alembic/versions/156_add_model_connections.py +++ b/surfsense_backend/alembic/versions/156_add_model_connections.py @@ -17,13 +17,6 @@ branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None -connection_protocol = postgresql.ENUM( - "OLLAMA", - "OPENAI_COMPATIBLE", - "ANTHROPIC", - name="connectionprotocol", - create_type=False, -) connection_scope = postgresql.ENUM( "GLOBAL", "SEARCH_SPACE", @@ -73,36 +66,67 @@ def _add_searchspace_column_if_missing(column_name: str) -> None: op.add_column("searchspaces", sa.Column(column_name, sa.Integer(), nullable=True)) +def _drop_column_if_exists(table_name: str, column_name: str) -> None: + if _column_exists(table_name, column_name): + op.drop_column(table_name, column_name) + + +def _drop_index_if_exists(table_name: str, index_name: str) -> None: + if _index_exists(table_name, index_name): + op.drop_index(index_name, table_name=table_name) + + def upgrade() -> None: bind = op.get_bind() - connection_protocol.create(bind, checkfirst=True) - op.execute("ALTER TYPE connectionprotocol ADD VALUE IF NOT EXISTS 'ANTHROPIC'") connection_scope.create(bind, checkfirst=True) model_source.create(bind, checkfirst=True) if _table_exists("connections"): - if _column_exists("connections", "native_provider") and not _column_exists( - "connections", "litellm_provider" + if _column_exists("connections", "litellm_provider") and not _column_exists( + "connections", "provider" + ): + op.alter_column( + "connections", + "litellm_provider", + new_column_name="provider", + existing_type=sa.String(length=100), + existing_nullable=True, + ) + op.alter_column( + "connections", + "provider", + existing_type=sa.String(length=100), + nullable=False, + ) + elif _column_exists("connections", "native_provider") and not _column_exists( + "connections", "provider" ): op.alter_column( "connections", "native_provider", - new_column_name="litellm_provider", + new_column_name="provider", existing_type=sa.String(length=100), existing_nullable=True, ) - elif not _column_exists("connections", "litellm_provider"): + op.alter_column( + "connections", + "provider", + existing_type=sa.String(length=100), + nullable=False, + ) + elif not _column_exists("connections", "provider"): op.add_column( "connections", - sa.Column("litellm_provider", sa.String(length=100), nullable=True), + sa.Column("provider", sa.String(length=100), nullable=False), ) + _drop_index_if_exists("connections", "ix_connections_protocol") + _drop_column_if_exists("connections", "protocol") else: op.create_table( "connections", sa.Column("id", sa.Integer(), nullable=False), sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("protocol", connection_protocol, nullable=False), - sa.Column("litellm_provider", sa.String(length=100), nullable=True), + sa.Column("provider", sa.String(length=100), nullable=False), sa.Column("base_url", sa.String(length=500), nullable=True), sa.Column("api_key", sa.String(), nullable=True), sa.Column( @@ -131,18 +155,20 @@ def upgrade() -> None: sa.PrimaryKeyConstraint("id"), ) if _index_exists("connections", "ix_connections_native_provider") and not _index_exists( - "connections", "ix_connections_litellm_provider" + "connections", "ix_connections_provider" ): op.execute( "ALTER INDEX ix_connections_native_provider " - "RENAME TO ix_connections_litellm_provider" + "RENAME TO ix_connections_provider" ) - _create_index_if_missing("ix_connections_protocol", "connections", ["protocol"]) - _create_index_if_missing( - "ix_connections_litellm_provider", - "connections", - ["litellm_provider"], - ) + if _index_exists("connections", "ix_connections_litellm_provider") and not _index_exists( + "connections", "ix_connections_provider" + ): + op.execute( + "ALTER INDEX ix_connections_litellm_provider " + "RENAME TO ix_connections_provider" + ) + _create_index_if_missing("ix_connections_provider", "connections", ["provider"]) _create_index_if_missing("ix_connections_scope", "connections", ["scope"]) if not _table_exists("models"): @@ -159,24 +185,11 @@ def upgrade() -> None: server_default="DISCOVERED", nullable=False, ), - sa.Column( - "capabilities", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column( - "capabilities_declared", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column( - "capabilities_verified", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), + sa.Column("supports_chat", sa.Boolean(), nullable=True), + sa.Column("max_input_tokens", sa.Integer(), nullable=True), + sa.Column("supports_image_input", sa.Boolean(), nullable=True), + sa.Column("supports_tools", sa.Boolean(), nullable=True), + sa.Column("supports_image_generation", sa.Boolean(), nullable=True), sa.Column( "capabilities_override", postgresql.JSONB(astext_type=sa.Text()), @@ -198,6 +211,24 @@ def upgrade() -> None: "connection_id", "model_id", name="uq_models_connection_model_id" ), ) + else: + if not _column_exists("models", "supports_chat"): + op.add_column("models", sa.Column("supports_chat", sa.Boolean(), nullable=True)) + if not _column_exists("models", "max_input_tokens"): + op.add_column("models", sa.Column("max_input_tokens", sa.Integer(), nullable=True)) + if not _column_exists("models", "supports_image_input"): + op.add_column( + "models", sa.Column("supports_image_input", sa.Boolean(), nullable=True) + ) + if not _column_exists("models", "supports_tools"): + op.add_column("models", sa.Column("supports_tools", sa.Boolean(), nullable=True)) + if not _column_exists("models", "supports_image_generation"): + op.add_column( + "models", sa.Column("supports_image_generation", sa.Boolean(), nullable=True) + ) + _drop_column_if_exists("models", "capabilities") + _drop_column_if_exists("models", "capabilities_declared") + _drop_column_if_exists("models", "capabilities_verified") _create_index_if_missing("ix_models_connection_id", "models", ["connection_id"]) _create_index_if_missing("ix_models_model_id", "models", ["model_id"]) _create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"]) @@ -206,6 +237,8 @@ def upgrade() -> None: _add_searchspace_column_if_missing("image_gen_model_id") _add_searchspace_column_if_missing("vision_model_id") + op.execute("DROP TYPE IF EXISTS connectionprotocol") + def downgrade() -> None: op.drop_column("searchspaces", "vision_model_id") @@ -218,11 +251,9 @@ def downgrade() -> None: op.drop_table("models") op.drop_index(op.f("ix_connections_scope"), table_name="connections") - op.drop_index(op.f("ix_connections_litellm_provider"), table_name="connections") - op.drop_index(op.f("ix_connections_protocol"), table_name="connections") + op.drop_index(op.f("ix_connections_provider"), table_name="connections") op.drop_table("connections") bind = op.get_bind() model_source.drop(bind, checkfirst=True) connection_scope.drop(bind, checkfirst=True) - connection_protocol.drop(bind, checkfirst=True) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py index fda327750..505831faa 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py @@ -29,6 +29,7 @@ from app.services.auto_model_pin_service import ( auto_model_candidates, choose_auto_model_candidate, ) +from app.services.model_capabilities import has_capability from app.services.model_resolver import to_litellm from app.utils.signed_image_urls import generate_image_token @@ -146,9 +147,7 @@ def create_generate_image_tool( if config_id < 0: global_model = _get_global_model(config_id) - if not global_model or not ( - global_model.get("capabilities") or {} - ).get("image_gen"): + if not global_model or not has_capability(global_model, "image_gen"): err = f"Image generation model {config_id} not found" return _failed({"error": err}, error=err) global_connection = _get_global_connection( @@ -191,7 +190,7 @@ def create_generate_image_tool( ): err = f"Image generation model {config_id} not found" return _failed({"error": err}, error=err) - if not (db_model.capabilities or {}).get("image_gen"): + if not has_capability(db_model, "image_gen"): err = f"Model {config_id} is not image-generation capable" return _failed({"error": err}, error=err) diff --git a/surfsense_backend/app/agents/chat/runtime/llm_config.py b/surfsense_backend/app/agents/chat/runtime/llm_config.py index b9344e001..efc188df8 100644 --- a/surfsense_backend/app/agents/chat/runtime/llm_config.py +++ b/surfsense_backend/app/agents/chat/runtime/llm_config.py @@ -49,16 +49,19 @@ def _sanitize_messages(messages: list[BaseMessage]) -> list[BaseMessage]: reject the blank text. The OpenAI spec says ``content`` should be ``null`` when an assistant message only carries tool calls. """ + sanitized: list[BaseMessage] = [] for msg in messages: - if isinstance(msg.content, list): - msg.content = _sanitize_content(msg.content) + next_msg = msg.model_copy(deep=True) + if isinstance(next_msg.content, list): + next_msg.content = _sanitize_content(next_msg.content) if ( - isinstance(msg, AIMessage) - and (not msg.content or msg.content == "") - and getattr(msg, "tool_calls", None) + isinstance(next_msg, AIMessage) + and (not next_msg.content or next_msg.content == "") + and getattr(next_msg, "tool_calls", None) ): - msg.content = None # type: ignore[assignment] - return messages + next_msg.content = None # type: ignore[assignment] + sanitized.append(next_msg) + return sanitized class SanitizedChatLiteLLM(ChatLiteLLM): @@ -89,6 +92,22 @@ class SanitizedChatLiteLLM(ChatLiteLLM): ): yield chunk + async def _agenerate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + stream: bool | None = None, + **kwargs: Any, + ) -> ChatResult: + return await super()._agenerate( + _sanitize_messages(messages), + stop=stop, + run_manager=run_manager, + stream=stream, + **kwargs, + ) + def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None: """Attach a ``profile`` dict to ChatLiteLLM with model context metadata.""" @@ -210,7 +229,7 @@ class AgentConfig: # BYOK rows have no curated flag; ask LiteLLM (default-allow on # unknown). The streaming safety net still blocks explicit text-only. supports_image_input=derive_supports_image_input( - litellm_provider=provider_value.lower(), + provider=provider_value.lower(), model_name=config.model_name, base_model=base_model, custom_provider=config.custom_provider, @@ -229,7 +248,7 @@ class AgentConfig: system_instructions = yaml_config.get("system_instructions", "") - provider = yaml_config.get("litellm_provider", "") + provider = yaml_config.get("provider") or yaml_config.get("litellm_provider", "") model_name = yaml_config.get("model_name", "") custom_provider = yaml_config.get("custom_provider") litellm_params = yaml_config.get("litellm_params") or {} @@ -245,7 +264,7 @@ class AgentConfig: supports_image_input = bool(yaml_config.get("supports_image_input")) else: supports_image_input = derive_supports_image_input( - litellm_provider=provider, + provider=provider, model_name=model_name, base_model=base_model, custom_provider=custom_provider, @@ -396,8 +415,8 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: if llm_config.get("custom_provider"): model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}" else: - litellm_provider = llm_config.get("litellm_provider", "openai") - model_string = f"{litellm_provider}/{llm_config['model_name']}" + provider = llm_config.get("provider") or llm_config.get("litellm_provider", "openai") + model_string = f"{provider}/{llm_config['model_name']}" litellm_kwargs = { "model": model_string, diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index d3f5dce2a..6dfe6a776 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -33,7 +33,6 @@ from app.config import ( initialize_llm_router, initialize_openrouter_integration, initialize_pricing_registration, - initialize_vision_llm_router, ) from app.db import User, create_db_and_tables, get_async_session from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError @@ -622,7 +621,6 @@ async def lifespan(app: FastAPI): initialize_pricing_registration() initialize_llm_router() initialize_image_gen_router() - initialize_vision_llm_router() # Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays # worker readiness. ``shield`` so Uvicorn cancelling startup diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py index 0e852b801..7ab98203e 100644 --- a/surfsense_backend/app/celery_app.py +++ b/surfsense_backend/app/celery_app.py @@ -115,14 +115,12 @@ def init_worker(**kwargs): initialize_llm_router, initialize_openrouter_integration, initialize_pricing_registration, - initialize_vision_llm_router, ) initialize_openrouter_integration() initialize_pricing_registration() initialize_llm_router() initialize_image_gen_router() - initialize_vision_llm_router() # Celery configuration, sourced from the central Config singleton diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index fd8b29116..eb2a7a18c 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -103,7 +103,7 @@ def load_global_llm_configs(): else None ) cfg["supports_image_input"] = derive_supports_image_input( - litellm_provider=cfg.get("litellm_provider"), + provider=cfg.get("provider") or cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), @@ -122,7 +122,7 @@ def load_global_llm_configs(): # Stamp Auto (Fastest) ranking metadata. YAML configs are always # Tier A — operator-curated, locked first when premium-eligible. # The OpenRouter refresh tick later re-stamps health for any cfg - # whose litellm_provider == "openrouter" via _enrich_health. + # whose provider == "openrouter" via _enrich_health. try: from app.services.quality_score import static_score_yaml @@ -132,7 +132,7 @@ def load_global_llm_configs(): cfg["quality_score_static"] = static_q cfg["quality_score"] = static_q cfg["quality_score_health"] = None - # YAML cfgs whose litellm_provider is openrouter are also subject + # YAML cfgs whose provider is openrouter are also subject # to health gating against their own /endpoints data — a # hand-picked dead OR model is still dead. _enrich_health # re-stamps health_gated for them on the next refresh tick. @@ -362,8 +362,8 @@ def initialize_openrouter_integration(): else: print("Info: OpenRouter integration enabled but no models fetched") - # Image generation + vision LLM emissions are opt-in (issue L). - # Both reuse the catalogue already cached by ``service.initialize`` + # Image generation emissions reuse the catalogue already cached by + # ``service.initialize`` # so we don't make additional network calls here. if settings.get("image_generation_enabled"): try: @@ -377,18 +377,6 @@ def initialize_openrouter_integration(): except Exception as e: print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}") - if settings.get("vision_enabled"): - try: - vision_configs = service.get_vision_llm_configs() - if vision_configs: - config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs) - print( - f"Info: OpenRouter integration added {len(vision_configs)} " - f"vision LLM models" - ) - except Exception as e: - print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}") - refresh_global_model_catalog() except Exception as e: print(f"Warning: Failed to initialize OpenRouter integration: {e}") @@ -399,7 +387,6 @@ def materialize_global_configs(): return materialize_global_model_catalog( chat_configs=getattr(config, "GLOBAL_LLM_CONFIGS", []), - vision_configs=getattr(config, "GLOBAL_VISION_LLM_CONFIGS", []), image_configs=getattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", []), ) @@ -493,29 +480,9 @@ def initialize_image_gen_router(): def initialize_vision_llm_router(): - vision_configs = load_global_vision_llm_configs() - # Reuse the router settings already parsed at Config construction. The - # *configs* list is intentionally re-read from YAML (it must exclude the - # OpenRouter-injected dynamic models held in config.GLOBAL_VISION_LLM_CONFIGS). - router_settings = config.VISION_LLM_ROUTER_SETTINGS - - if not vision_configs: - print( - "Info: No global vision LLM configs found, " - "Vision LLM Auto mode will not be available" - ) - return - - try: - from app.services.vision_llm_router_service import VisionLLMRouterService - - VisionLLMRouterService.initialize(vision_configs, router_settings) - print( - f"Info: Vision LLM Router initialized with {len(vision_configs)} models " - f"(strategy: {router_settings.get('routing_strategy', 'usage-based-routing')})" - ) - except Exception as e: - print(f"Warning: Failed to initialize Vision LLM Router: {e}") + # Retired: vision Auto now uses shared capability-filtered model selection + # over GLOBAL/BYOK chat models with supports_image_input=true. + return class Config: @@ -874,7 +841,6 @@ class Config: GLOBAL_CONNECTIONS, GLOBAL_MODELS = _materialize_global_model_catalog( chat_configs=GLOBAL_LLM_CONFIGS, - vision_configs=GLOBAL_VISION_LLM_CONFIGS, image_configs=GLOBAL_IMAGE_GEN_CONFIGS, ) del _materialize_global_model_catalog diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 4c628b05a..9053d7055 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -280,12 +280,6 @@ class VisionProvider(StrEnum): CUSTOM = "CUSTOM" -class ConnectionProtocol(StrEnum): - OLLAMA = "OLLAMA" - OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" - ANTHROPIC = "ANTHROPIC" - - class ConnectionScope(StrEnum): GLOBAL = "GLOBAL" SEARCH_SPACE = "SEARCH_SPACE" @@ -1662,8 +1656,7 @@ class Report(BaseModel, TimestampMixin): class Connection(BaseModel, TimestampMixin): __tablename__ = "connections" - protocol = Column(SQLAlchemyEnum(ConnectionProtocol), nullable=False, index=True) - litellm_provider = Column(String(100), nullable=True, index=True) + provider = Column(String(100), nullable=False, index=True) base_url = Column(String(500), nullable=True) api_key = Column(String, nullable=True) extra = Column(JSONB, nullable=False, default=dict, server_default="{}") @@ -1715,9 +1708,11 @@ class Model(BaseModel, TimestampMixin): default=ModelSource.DISCOVERED, server_default=ModelSource.DISCOVERED.value, ) - capabilities = Column(JSONB, nullable=False, default=dict, server_default="{}") - capabilities_declared = Column(JSONB, nullable=False, default=dict, server_default="{}") - capabilities_verified = Column(JSONB, nullable=False, default=dict, server_default="{}") + supports_chat = Column(Boolean, nullable=True) + max_input_tokens = Column(Integer, nullable=True) + supports_image_input = Column(Boolean, nullable=True) + supports_tools = Column(Boolean, nullable=True) + supports_image_generation = Column(Boolean, nullable=True) capabilities_override = Column(JSONB, nullable=False, default=dict, server_default="{}") embedding_dimension = Column(Integer, nullable=True) enabled = Column(Boolean, nullable=False, default=True, server_default="true") diff --git a/surfsense_backend/app/routes/anonymous_chat_routes.py b/surfsense_backend/app/routes/anonymous_chat_routes.py index aba1a3a12..84420e738 100644 --- a/surfsense_backend/app/routes/anonymous_chat_routes.py +++ b/surfsense_backend/app/routes/anonymous_chat_routes.py @@ -132,7 +132,7 @@ async def list_anonymous_models(): id=cfg.get("id", 0), name=cfg.get("name", ""), description=cfg.get("description"), - provider=cfg.get("litellm_provider", ""), + provider=cfg.get("provider") or cfg.get("litellm_provider", ""), model_name=cfg.get("model_name", ""), billing_tier=cfg.get("billing_tier", "free"), is_premium=cfg.get("billing_tier", "free") == "premium", @@ -161,7 +161,7 @@ async def get_anonymous_model(slug: str): id=cfg.get("id", 0), name=cfg.get("name", ""), description=cfg.get("description"), - provider=cfg.get("litellm_provider", ""), + provider=cfg.get("provider") or cfg.get("litellm_provider", ""), model_name=cfg.get("model_name", ""), billing_tier=cfg.get("billing_tier", "free"), is_premium=cfg.get("billing_tier", "free") == "premium", diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 5be1cedf1..e8f14bd71 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -52,6 +52,7 @@ from app.services.auto_model_pin_service import ( choose_auto_model_candidate, ) from app.services.model_resolver import to_litellm +from app.services.model_capabilities import has_capability from app.users import current_active_user from app.utils.rbac import check_permission from app.utils.signed_image_urls import verify_image_token @@ -166,7 +167,7 @@ async def _execute_image_generation( if config_id < 0: global_model = _get_global_model(config_id) - if not global_model or not (global_model.get("capabilities") or {}).get("image_gen"): + if not global_model or not has_capability(global_model, "image_gen"): raise ValueError(f"Global image generation model {config_id} not found") global_connection = _get_global_connection(global_model["connection_id"]) if not global_connection: @@ -200,7 +201,7 @@ async def _execute_image_generation( raise ValueError(f"Image generation model {config_id} not found") if conn.user_id is not None and conn.user_id != search_space.user_id: raise ValueError(f"Image generation model {config_id} not found") - if not (db_model.capabilities or {}).get("image_gen"): + if not has_capability(db_model, "image_gen"): raise ValueError(f"Model {config_id} is not image-generation capable") model_string, resolved_kwargs = to_litellm( @@ -272,7 +273,7 @@ async def get_global_image_gen_configs( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("litellm_provider"), + "provider": cfg.get("provider") or cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index cae951c3a..ecb86711e 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -8,7 +8,6 @@ from sqlalchemy.orm import selectinload from app.config import config from app.db import ( Connection, - ConnectionProtocol, ConnectionScope, Model, ModelSource, @@ -22,6 +21,7 @@ from app.schemas import ( ConnectionRead, ConnectionUpdate, ModelCreate, + ModelProviderRead, ModelRead, ModelRolesRead, ModelRolesUpdate, @@ -34,6 +34,8 @@ from app.services.model_connection_service import ( persist_verification, test_model, ) +from app.services.model_capabilities import has_capability +from app.services.provider_registry import REGISTRY from app.users import current_active_user from app.utils.rbac import check_permission @@ -41,16 +43,6 @@ router = APIRouter() logger = logging.getLogger(__name__) -def _default_litellm_provider(protocol: ConnectionProtocol | str) -> str: - protocol_value = getattr(protocol, "value", str(protocol)) - defaults = { - ConnectionProtocol.OLLAMA.value: "ollama_chat", - ConnectionProtocol.ANTHROPIC.value: "anthropic", - ConnectionProtocol.OPENAI_COMPATIBLE.value: "openai", - } - return defaults.get(protocol_value, "openai") - - def _model_read(model: Model | dict) -> ModelRead: return ModelRead.model_validate(model) @@ -68,8 +60,7 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None return ConnectionRead( id=conn.id, - protocol=conn.protocol, - litellm_provider=conn.litellm_provider, + provider=conn.provider, base_url=conn.base_url, extra=conn.extra or {}, scope=conn.scope, @@ -85,6 +76,60 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None ) +def _apply_model_facts(model: Model, facts: dict) -> None: + model.supports_chat = facts.get("supports_chat") + model.max_input_tokens = facts.get("max_input_tokens") + model.supports_image_input = facts.get("supports_image_input") + model.supports_tools = facts.get("supports_tools") + model.supports_image_generation = facts.get("supports_image_generation") + + +def _default_model_for(models: list[Model], capability: str) -> int | None: + for model in models: + if model.enabled and has_capability(model, capability): + return model.id + return None + + +async def _default_unset_roles( + session: AsyncSession, + conn: Connection, + models: list[Model], +) -> None: + if conn.scope != ConnectionScope.SEARCH_SPACE or conn.search_space_id is None: + return + search_space = await _get_search_space(session, conn.search_space_id) + if search_space.chat_model_id is None: + search_space.chat_model_id = _default_model_for(models, "chat") + if search_space.vision_model_id is None: + vision_default = None + if search_space.chat_model_id: + chat_model = next((m for m in models if m.id == search_space.chat_model_id), None) + if chat_model and has_capability(chat_model, "vision"): + vision_default = chat_model.id + search_space.vision_model_id = vision_default or _default_model_for(models, "vision") + if search_space.image_gen_model_id is None: + search_space.image_gen_model_id = _default_model_for(models, "image_gen") + + +@router.get("/model-providers", response_model=list[ModelProviderRead]) +async def list_model_providers(user: User = Depends(current_active_user)): + del user + local_only = {"ollama_chat", "lm_studio"} + return [ + ModelProviderRead( + provider=provider, + transport=spec.transport.value, + discovery=spec.discovery, + default_base_url=spec.default_base_url, + base_url_required=spec.base_url_required, + auth_style=spec.auth_style, + local_only=provider in local_only, + ) + for provider, spec in sorted(REGISTRY.items()) + ] + + async def _get_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace: result = await session.execute(select(SearchSpace).where(SearchSpace.id == search_space_id)) search_space = result.scalars().first() @@ -180,8 +225,6 @@ async def create_connection( "You don't have permission to create model connections in this search space", ) payload = data.model_dump(exclude={"search_space_id"}) - if not payload.get("litellm_provider"): - payload["litellm_provider"] = _default_litellm_provider(data.protocol) conn = Connection( **payload, @@ -254,24 +297,21 @@ async def discover_connection_models( model_id=item["model_id"], display_name=item.get("display_name"), source=item["source"], - capabilities=item["capabilities"], - capabilities_declared=item["capabilities"], - capabilities_verified={}, capabilities_override={}, enabled=False, catalog=item.get("metadata") or {}, ) + _apply_model_facts(db_model, item) session.add(db_model) else: db_model.display_name = item.get("display_name") or db_model.display_name - db_model.capabilities_declared = item["capabilities"] - db_model.capabilities = { - **item["capabilities"], - **(db_model.capabilities_override or {}), - } + _apply_model_facts(db_model, item) db_model.catalog = item.get("metadata") or db_model.catalog await session.commit() conn = await _load_connection(session, connection_id) + await _default_unset_roles(session, conn, list(conn.models)) + await session.commit() + conn = await _load_connection(session, connection_id) return [_model_read(model) for model in conn.models] @@ -297,16 +337,17 @@ async def add_manual_model( model_id=model_id, display_name=data.display_name or None, source=ModelSource.MANUAL, - capabilities=capabilities, - capabilities_declared=capabilities, - capabilities_verified={}, capabilities_override={}, enabled=True, catalog={}, ) + _apply_model_facts(model, capabilities) session.add(model) await session.commit() await session.refresh(model) + conn = await _load_connection(session, connection_id) + await _default_unset_roles(session, conn, list(conn.models)) + await session.commit() return _model_read(model) @@ -327,11 +368,6 @@ async def update_model( update = data.model_dump(exclude_unset=True) for key, value in update.items(): setattr(model, key, value) - if "capabilities_override" in update: - model.capabilities = { - **(model.capabilities_declared or {}), - **(model.capabilities_override or {}), - } await session.commit() await session.refresh(model) return _model_read(model) diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 0e4e557be..b5bc2571e 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -1741,12 +1741,11 @@ async def handle_new_chat( if not search_space: raise HTTPException(status_code=404, detail="Search space not found") - # Use agent_llm_id from search space for chat operations - # Positive IDs load from NewLLMConfig database table - # Negative IDs load from YAML global configs - # Falls back to -1 (first global config) if not configured + # Use the converged model-connections role for chat operations. + # Positive IDs load Model + Connection rows; negative IDs load + # virtual GLOBAL models; 0 means Auto. llm_config_id = ( - search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 + search_space.chat_model_id if search_space.chat_model_id is not None else 0 ) # Release the read-transaction so we don't hold ACCESS SHARE locks @@ -2228,7 +2227,7 @@ async def regenerate_response( raise HTTPException(status_code=404, detail="Search space not found") llm_config_id = ( - search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 + search_space.chat_model_id if search_space.chat_model_id is not None else 0 ) # Release the read-transaction so we don't hold ACCESS SHARE locks @@ -2393,7 +2392,7 @@ async def resume_chat( raise HTTPException(status_code=404, detail="Search space not found") llm_config_id = ( - search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 + search_space.chat_model_id if search_space.chat_model_id is not None else 0 ) decisions = [d.model_dump() for d in request.decisions] diff --git a/surfsense_backend/app/routes/new_llm_config_routes.py b/surfsense_backend/app/routes/new_llm_config_routes.py index 531fa8730..adba5b5ae 100644 --- a/surfsense_backend/app/routes/new_llm_config_routes.py +++ b/surfsense_backend/app/routes/new_llm_config_routes.py @@ -57,7 +57,7 @@ def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead: litellm_params.get("base_model") if isinstance(litellm_params, dict) else None ) supports_image_input = derive_supports_image_input( - litellm_provider=provider_value.lower(), + provider=provider_value.lower(), model_name=config.model_name, base_model=base_model, custom_provider=config.custom_provider, @@ -147,7 +147,7 @@ async def get_global_new_llm_configs( else None ) supports_image_input = derive_supports_image_input( - litellm_provider=cfg.get("litellm_provider"), + provider=cfg.get("provider") or cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=cfg_base_model, custom_provider=cfg.get("custom_provider"), @@ -157,7 +157,7 @@ async def get_global_new_llm_configs( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("litellm_provider"), + "provider": cfg.get("provider") or cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 2cda04221..7c5fbf28b 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -419,7 +419,7 @@ async def _get_llm_config_by_id( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("litellm_provider"), + "provider": cfg.get("provider") or cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base"), @@ -490,7 +490,7 @@ async def _get_image_gen_config_by_id( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("litellm_provider"), + "provider": cfg.get("provider") or cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, @@ -550,7 +550,7 @@ async def _get_vision_llm_config_by_id( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("litellm_provider"), + "provider": cfg.get("provider") or cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py index df218daac..b93d25b9c 100644 --- a/surfsense_backend/app/routes/vision_llm_routes.py +++ b/surfsense_backend/app/routes/vision_llm_routes.py @@ -96,7 +96,7 @@ async def get_global_vision_llm_configs( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("litellm_provider"), + "provider": cfg.get("provider") or cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 2a06eca5c..dde9fcef1 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -49,6 +49,7 @@ from .model_connections import ( ConnectionRead, ConnectionUpdate, ModelCreate, + ModelProviderRead, ModelRead, ModelRolesRead, ModelRolesUpdate, diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py index 306dd63c8..c081a193d 100644 --- a/surfsense_backend/app/schemas/model_connections.py +++ b/surfsense_backend/app/schemas/model_connections.py @@ -4,7 +4,7 @@ from typing import Any from pydantic import BaseModel, ConfigDict, Field -from app.db import ConnectionProtocol, ConnectionScope, ModelSource +from app.db import ConnectionScope, ModelSource class ModelRead(BaseModel): @@ -13,9 +13,11 @@ class ModelRead(BaseModel): model_id: str display_name: str | None = None source: ModelSource | str - capabilities: dict[str, Any] - capabilities_declared: dict[str, Any] = Field(default_factory=dict) - capabilities_verified: dict[str, Any] = Field(default_factory=dict) + supports_chat: bool | None = None + max_input_tokens: int | None = None + supports_image_input: bool | None = None + supports_tools: bool | None = None + supports_image_generation: bool | None = None capabilities_override: dict[str, Any] = Field(default_factory=dict) embedding_dimension: int | None = None enabled: bool @@ -28,8 +30,7 @@ class ModelRead(BaseModel): class ConnectionRead(BaseModel): id: int - protocol: ConnectionProtocol | str - litellm_provider: str | None = None + provider: str base_url: str | None = None extra: dict[str, Any] = Field(default_factory=dict) scope: ConnectionScope | str @@ -47,8 +48,7 @@ class ConnectionRead(BaseModel): class ConnectionCreate(BaseModel): - protocol: ConnectionProtocol - litellm_provider: str | None = Field(None, max_length=100) + provider: str = Field(..., max_length=100) base_url: str | None = Field(None, max_length=500) api_key: str | None = None extra: dict[str, Any] = Field(default_factory=dict) @@ -58,7 +58,7 @@ class ConnectionCreate(BaseModel): class ConnectionUpdate(BaseModel): - litellm_provider: str | None = Field(None, max_length=100) + provider: str | None = Field(None, max_length=100) base_url: str | None = Field(None, max_length=500) api_key: str | None = None extra: dict[str, Any] | None = None @@ -79,9 +79,24 @@ class ModelCreate(BaseModel): class ModelUpdate(BaseModel): display_name: str | None = Field(None, max_length=255) enabled: bool | None = None + supports_chat: bool | None = None + max_input_tokens: int | None = None + supports_image_input: bool | None = None + supports_tools: bool | None = None + supports_image_generation: bool | None = None capabilities_override: dict[str, Any] | None = None +class ModelProviderRead(BaseModel): + provider: str + transport: str + discovery: str + default_base_url: str | None = None + base_url_required: bool + auth_style: str + local_only: bool = False + + class VerifyConnectionResponse(BaseModel): status: str ok: bool diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index ee8c4b8dc..652029e76 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -27,6 +27,7 @@ from sqlalchemy.orm import selectinload from app.config import config from app.db import Connection, Model, NewChatThread +from app.services.model_capabilities import has_capability from app.services.quality_score import _QUALITY_TOP_K from app.services.token_quota_service import TokenQuotaService @@ -62,18 +63,13 @@ def _is_usable_global_config(cfg: dict) -> bool: return bool( cfg.get("id") is not None and cfg.get("model_name") - and cfg.get("litellm_provider") + and (cfg.get("provider") or cfg.get("litellm_provider")) and cfg.get("api_key") ) def _has_capability(model: dict | Model, capability: str) -> bool: - caps = ( - model.get("capabilities", {}) - if isinstance(model, dict) - else model.capabilities or {} - ) - return bool(caps.get(capability)) + return has_capability(model, capability) def _prune_runtime_cooldowns(now_ts: float | None = None) -> None: @@ -196,7 +192,7 @@ def _cfg_supports_image_input(cfg: dict) -> bool: else None ) return derive_supports_image_input( - litellm_provider=cfg.get("litellm_provider"), + provider=cfg.get("provider") or cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), @@ -253,9 +249,13 @@ def _global_candidates( "model_id": model.get("model_id"), "source": "global", "connection": connection, - "capabilities": model.get("capabilities") or {}, + "supports_chat": model.get("supports_chat"), + "supports_image_input": model.get("supports_image_input"), + "supports_tools": model.get("supports_tools"), + "supports_image_generation": model.get("supports_image_generation"), + "capabilities_override": model.get("capabilities_override") or {}, "billing_tier": model.get("billing_tier", "free"), - "litellm_provider": connection.get("litellm_provider"), + "provider": connection.get("provider"), "model_name": model.get("model_id"), "auto_pin_tier": catalog.get("auto_pin_tier") or cfg.get("auto_pin_tier") @@ -310,9 +310,13 @@ async def _db_candidates( "model_id": model.model_id, "source": "db", "connection": conn, - "capabilities": model.capabilities or {}, + "supports_chat": model.supports_chat, + "supports_image_input": model.supports_image_input, + "supports_tools": model.supports_tools, + "supports_image_generation": model.supports_image_generation, + "capabilities_override": model.capabilities_override or {}, "billing_tier": "byok", - "litellm_provider": conn.litellm_provider, + "provider": conn.provider, "model_name": model.model_id, "auto_pin_tier": catalog.get("auto_pin_tier") or "BYOK", "quality_score": catalog.get("quality_score") or 75, @@ -357,7 +361,7 @@ def _is_preferred_premium_auto_config(cfg: dict) -> bool: return ( cfg.get("source") == "global" and _tier_of(cfg) == "premium" - and str(cfg.get("litellm_provider", "")).lower() == "azure" + and str(cfg.get("provider", "")).lower() == "azure" and str(cfg.get("model_name", "")).lower() == "gpt-5.4" ) diff --git a/surfsense_backend/app/services/global_model_catalog.py b/surfsense_backend/app/services/global_model_catalog.py index e40b3a942..ca1249497 100644 --- a/surfsense_backend/app/services/global_model_catalog.py +++ b/surfsense_backend/app/services/global_model_catalog.py @@ -18,8 +18,7 @@ def _connection_key(conn: dict[str, Any]) -> tuple[Any, ...]: # Deliberately includes api_key because two operator-owned credentials for # the same provider/base can have different quota/rate limits upstream. return ( - conn.get("protocol"), - conn.get("litellm_provider"), + conn.get("provider"), conn.get("base_url"), conn.get("api_key"), _freeze(conn.get("extra") or {}), @@ -34,16 +33,6 @@ def _freeze(value: Any) -> Any: return value -def _capabilities_for(role: str, config: dict[str, Any]) -> dict[str, bool]: - return { - "chat": role == "chat", - "vision": role == "vision" or bool(config.get("supports_image_input")), - "image_gen": role == "image_gen", - "embedding": False, - "tools": bool(config.get("supports_tools", False)), - } - - def _catalog_metadata(config: dict[str, Any]) -> dict[str, Any]: return { "billing_tier": config.get("billing_tier", "free"), @@ -72,7 +61,6 @@ def _catalog_metadata(config: dict[str, Any]) -> dict[str, Any]: def materialize_global_model_catalog( *, chat_configs: list[dict[str, Any]], - vision_configs: list[dict[str, Any]], image_configs: list[dict[str, Any]], ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: connections: list[dict[str, Any]] = [] @@ -109,9 +97,13 @@ def materialize_global_model_catalog( "model_id": config["model_name"], "display_name": config.get("name") or config["model_name"], "source": "MANUAL", - "capabilities": _capabilities_for(role, config), - "capabilities_declared": _capabilities_for(role, config), - "capabilities_verified": _capabilities_for(role, config), + "supports_chat": role == "chat", + "max_input_tokens": config.get("max_input_tokens"), + "supports_image_input": ( + role == "chat" and bool(config.get("supports_image_input")) + ), + "supports_tools": bool(config.get("supports_tools", False)), + "supports_image_generation": role == "image_gen", "capabilities_override": {}, "embedding_dimension": None, "enabled": True, @@ -125,10 +117,6 @@ def materialize_global_model_catalog( if cfg.get("is_auto_mode"): continue add_config(cfg, "chat") - for cfg in vision_configs: - if cfg.get("is_auto_mode"): - continue - add_config(cfg, "vision") for cfg in image_configs: if cfg.get("is_auto_mode"): continue diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 86a9c8556..eadb4dbf8 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -15,6 +15,7 @@ from app.services.auto_model_pin_service import ( choose_auto_model_candidate, ) from app.services.llm_router_service import AUTO_MODE_ID, ChatLiteLLMRouter, is_auto_mode +from app.services.model_capabilities import has_capability from app.services.model_resolver import native_connection_from_config, to_litellm from app.services.token_tracking_service import token_tracker @@ -76,7 +77,7 @@ def _legacy_config_connection( api_version: str | None = None, ) -> tuple[str, dict]: cfg = { - "litellm_provider": provider.lower(), + "provider": provider.lower(), "model_name": model_name, "api_key": api_key, "api_base": api_base, @@ -136,12 +137,7 @@ def get_global_connection(connection_id: int) -> dict | None: def _has_capability(model: dict | Model, capability: str) -> bool: - caps = ( - model.get("capabilities", {}) - if isinstance(model, dict) - else model.capabilities or {} - ) - return bool(caps.get(capability)) + return has_capability(model, capability) def _chat_litellm_from_resolved( @@ -420,8 +416,6 @@ async def get_vision_llm( unwrapped — they don't consume premium credit (issue M). """ from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM - from app.services.vision_llm_router_service import is_vision_auto_mode - try: result = await session.execute( select(SearchSpace).where(SearchSpace.id == search_space_id) @@ -468,7 +462,7 @@ async def get_vision_llm( logger.error(f"No vision LLM configured for search space {search_space_id}") return None - if is_vision_auto_mode(config_id): + if config_id == AUTO_MODE_ID: candidates = await auto_model_candidates( session, search_space_id=search_space_id, diff --git a/surfsense_backend/app/services/model_capabilities.py b/surfsense_backend/app/services/model_capabilities.py new file mode 100644 index 000000000..fb7681f35 --- /dev/null +++ b/surfsense_backend/app/services/model_capabilities.py @@ -0,0 +1,36 @@ +"""Override-aware model capability lookup.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +CAPABILITY_FIELDS = { + "chat": "supports_chat", + "vision": "supports_image_input", + "image_gen": "supports_image_generation", + "tools": "supports_tools", +} + + +def _get_value(model: Any, key: str) -> Any: + if isinstance(model, Mapping): + return model.get(key) + return getattr(model, key, None) + + +def has_capability(model: Any, capability: str) -> bool: + field = CAPABILITY_FIELDS.get(capability) + if field is None: + return False + + override = _get_value(model, "capabilities_override") or {} + if isinstance(override, Mapping) and field in override: + return bool(override[field]) + if isinstance(override, Mapping) and capability in override: + return bool(override[capability]) + + return bool(_get_value(model, field)) + + +__all__ = ["CAPABILITY_FIELDS", "has_capability"] diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index 5e5b231f9..428af736e 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -11,8 +11,10 @@ from typing import Any import httpx import litellm -from app.db import Connection, ConnectionProtocol, Model, ModelSource +from app.db import Connection, Model, ModelSource from app.services.model_resolver import ensure_v1, to_litellm +from app.services.openrouter_model_normalizer import normalize_openrouter_models +from app.services.provider_registry import Transport, spec_for logger = logging.getLogger(__name__) @@ -41,6 +43,16 @@ def _anthropic_headers(conn: Connection) -> dict[str, str]: return headers +def _base_url_or_default(conn: Connection) -> str | None: + if conn.base_url: + return conn.base_url.rstrip("/") + if conn.provider == "openai": + return "https://api.openai.com/v1" + if conn.provider == "anthropic": + return "https://api.anthropic.com/v1" + return spec_for(conn.provider).default_base_url + + def _docker_hint(url: str | None, exc_or_status: Any) -> str: raw = str(exc_or_status) if not url: @@ -61,32 +73,30 @@ def _docker_hint(url: str | None, exc_or_status: Any) -> str: async def verify_connection(conn: Connection) -> VerifyResult: - if not conn.base_url: + spec = spec_for(conn.provider) + base_url = _base_url_or_default(conn) + if spec.base_url_required and not base_url: return VerifyResult("UNREACHABLE", False, "Base URL is required.") - if conn.protocol == ConnectionProtocol.OLLAMA: - url = f"{conn.base_url.rstrip('/')}/api/version" - elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: - url = f"{ensure_v1(conn.base_url)}/models" - elif conn.protocol == ConnectionProtocol.ANTHROPIC: - url = f"{conn.base_url.rstrip('/')}/models" + if spec.transport == Transport.OLLAMA and base_url: + url = f"{base_url.rstrip('/')}/api/version" + elif spec.discovery in {"openai_models", "openrouter"} and base_url: + url = f"{ensure_v1(base_url)}/models" + elif spec.discovery == "anthropic_models" and base_url: + url = f"{base_url.rstrip('/')}/models" else: - return VerifyResult("UNREACHABLE", False, "Unsupported connection protocol.") + return VerifyResult("OK", True, "Connection uses provider-native authentication.") try: async with httpx.AsyncClient(timeout=VERIFY_TIMEOUT_SECONDS) as client: - headers = ( - _anthropic_headers(conn) - if conn.protocol == ConnectionProtocol.ANTHROPIC - else _auth_headers(conn) - ) + headers = _anthropic_headers(conn) if spec.auth_style == "x-api-key" else _auth_headers(conn) response = await client.get(url, headers=headers) if response.status_code in (401, 403): return VerifyResult("AUTH_FAILED", False, "Authentication failed.") if response.status_code == 404: - if conn.protocol == ConnectionProtocol.OLLAMA and url.endswith("/v1/models"): + if spec.transport == Transport.OLLAMA and url.endswith("/v1/models"): message = "Ollama native API should not use /v1." - elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: + elif spec.transport == Transport.OPENAI_COMPATIBLE: message = "OpenAI-compatible servers should expose /v1/models." else: message = "Endpoint returned 404." @@ -94,11 +104,11 @@ async def verify_connection(conn: Connection) -> VerifyResult: response.raise_for_status() return VerifyResult("OK", True, "Connection verified.") except httpx.ConnectError as exc: - return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc)) + return VerifyResult("UNREACHABLE", False, _docker_hint(base_url, exc)) except httpx.TimeoutException as exc: return VerifyResult("UNREACHABLE", False, f"Connection timed out: {exc}") except httpx.HTTPError as exc: - return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc)) + return VerifyResult("UNREACHABLE", False, _docker_hint(base_url, exc)) async def persist_verification(conn: Connection) -> VerifyResult: @@ -109,123 +119,193 @@ async def persist_verification(conn: Connection) -> VerifyResult: return result -def _litellm_capabilities(model_string: str, model_id: str) -> dict[str, bool]: - capabilities = { - "chat": True, - "vision": False, - "tools": False, - "image_gen": False, - "embedding": False, - } - with contextlib.suppress(Exception): - capabilities["vision"] = bool(litellm.supports_vision(model=model_string)) - with contextlib.suppress(Exception): - capabilities["tools"] = bool(litellm.supports_function_calling(model=model_string)) - try: - info = litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {} - mode = str(info.get("mode") or "") - capabilities["embedding"] = mode == "embedding" - capabilities["image_gen"] = mode in {"image_generation", "image_generation_model"} - except Exception: - pass - return capabilities - - def _allowlist(conn: Connection) -> set[str]: - """Per-connection model-id allowlist stored in ``extra.model_ids``. - - Empty/absent means "no restriction" (discover everything), mirroring - OpenWebUI's behaviour. A non-empty list restricts discovery to those ids — - essential for providers like OpenRouter that expose hundreds of models. - """ raw = (conn.extra or {}).get("model_ids") or [] return {str(item).strip() for item in raw if str(item).strip()} -async def _discover_openai_shaped_models(conn: Connection, base_url: str | None) -> list[dict[str, Any]]: - if not base_url: +def _litellm_info(model_string: str, model_id: str) -> dict[str, Any]: + with contextlib.suppress(Exception): + info = litellm.get_model_info(model=model_string) + if isinstance(info, dict): + return info + return litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {} + + +def _classify_from_litellm(model_string: str, model_id: str) -> dict[str, Any]: + info = _litellm_info(model_string, model_id) + mode = info.get("mode") + supports_image_input = False + supports_tools = False + with contextlib.suppress(Exception): + supports_image_input = bool(litellm.supports_vision(model=model_string)) + with contextlib.suppress(Exception): + supports_tools = bool(litellm.supports_function_calling(model=model_string)) + return { + "supports_chat": mode in (None, "chat", "completion", "responses"), + "max_input_tokens": info.get("max_input_tokens") or info.get("max_tokens"), + "supports_image_input": supports_image_input, + "supports_tools": supports_tools, + "supports_image_generation": mode in {"image_generation", "image_generation_model"}, + } + + +def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, Any]: + metadata = metadata or {} + spec = spec_for(conn.provider) + model_string, _ = to_litellm(conn, model_id) + facts = _classify_from_litellm(model_string, model_id) + if spec.transport == Transport.OLLAMA: + caps = set(metadata.get("capabilities") or []) + details = metadata.get("details") or {} + facts.update( + { + "supports_chat": "embedding" not in caps, + "supports_image_input": "vision" in caps or facts["supports_image_input"], + "supports_tools": "tools" in caps or facts["supports_tools"], + "supports_image_generation": False, + "max_input_tokens": metadata.get("context_length") + or metadata.get("num_ctx") + or details.get("context_length") + or facts["max_input_tokens"], + } + ) + return facts + + +async def _discover_openai_shaped_models( + conn: Connection, base_url: str | None +) -> list[dict[str, Any]]: + resolved_base_url = base_url or _base_url_or_default(conn) + if not resolved_base_url: return [] - url = f"{ensure_v1(base_url)}/models" + url = f"{ensure_v1(resolved_base_url)}/models" async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: response = await client.get(url, headers=_auth_headers(conn)) response.raise_for_status() - return [ - { - "model_id": item.get("id"), - "display_name": item.get("name") or item.get("id"), - "source": ModelSource.DISCOVERED, - "capabilities": derive_capabilities(conn, item.get("id"), item), - "metadata": item, - } - for item in response.json().get("data", []) - if item.get("id") - ] + + results: list[dict[str, Any]] = [] + for item in response.json().get("data", []): + model_id = item.get("id") + if not model_id: + continue + results.append( + { + "model_id": model_id, + "display_name": item.get("name") or model_id, + "source": ModelSource.DISCOVERED, + **derive_capabilities(conn, model_id, item), + "metadata": item, + } + ) + return results async def _discover_anthropic_models(conn: Connection) -> list[dict[str, Any]]: - if not conn.base_url: + base_url = _base_url_or_default(conn) + if not base_url: return [] - url = f"{conn.base_url.rstrip('/')}/models" + url = f"{base_url.rstrip('/')}/models" async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: response = await client.get(url, headers=_anthropic_headers(conn)) response.raise_for_status() - models = response.json().get("data", []) - return [ - { - "model_id": item.get("id"), - "display_name": item.get("display_name") or item.get("id"), - "source": ModelSource.DISCOVERED, - "capabilities": derive_capabilities(conn, item.get("id"), item), - "metadata": item, - } - for item in models - if item.get("id") - ] + + results: list[dict[str, Any]] = [] + for item in response.json().get("data", []): + model_id = item.get("id") + if not model_id: + continue + results.append( + { + "model_id": model_id, + "display_name": item.get("display_name") or model_id, + "source": ModelSource.DISCOVERED, + **derive_capabilities(conn, model_id, item), + "metadata": item, + } + ) + return results -def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, bool]: - metadata = metadata or {} - if conn.protocol == ConnectionProtocol.OLLAMA: - caps = metadata.get("capabilities") or [] - capabilities = { - "chat": True, - "vision": "vision" in caps, - "tools": False, - "image_gen": False, - "embedding": "embedding" in caps, - } - return capabilities +async def _ollama_tags_then_show(conn: Connection) -> list[dict[str, Any]]: + if not conn.base_url: + return [] - model_string, _ = to_litellm(conn, model_id) - return _litellm_capabilities(model_string, model_id) + base_url = conn.base_url.rstrip("/") + async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: + response = await client.get(f"{base_url}/api/tags", headers=_auth_headers(conn)) + response.raise_for_status() + models = response.json().get("models", []) + results: list[dict[str, Any]] = [] + for item in models: + model_id = item.get("model") or item.get("name") + if not model_id: + continue + metadata = dict(item) + with contextlib.suppress(Exception): + show_response = await client.post( + f"{base_url}/api/show", + json={"model": model_id}, + headers=_auth_headers(conn), + ) + show_response.raise_for_status() + metadata.update(show_response.json()) + results.append( + { + "model_id": model_id, + "display_name": item.get("name") or model_id, + "source": ModelSource.DISCOVERED, + **derive_capabilities(conn, model_id, metadata), + "metadata": metadata, + } + ) + return results + + +async def _openrouter_models(conn: Connection) -> list[dict[str, Any]]: + base_url = _base_url_or_default(conn) or "https://openrouter.ai/api/v1" + async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: + response = await client.get(f"{ensure_v1(base_url)}/models", headers=_auth_headers(conn)) + response.raise_for_status() + return normalize_openrouter_models(response.json().get("data", [])) + + +def _litellm_static_models(conn: Connection) -> list[dict[str, Any]]: + provider = conn.provider + prefix = spec_for(provider).litellm_prefix or provider + results: list[dict[str, Any]] = [] + for model_string, metadata in litellm.model_cost.items(): + if not isinstance(model_string, str) or not model_string.startswith(f"{prefix}/"): + continue + model_id = model_string.split("/", 1)[1] + results.append( + { + "model_id": model_id, + "display_name": metadata.get("display_name") or model_id, + "source": ModelSource.DISCOVERED, + **_classify_from_litellm(model_string, model_id), + "metadata": metadata, + } + ) + return results async def discover_models(conn: Connection) -> list[dict[str, Any]]: allowlist = _allowlist(conn) + spec = spec_for(conn.provider) - if conn.protocol == ConnectionProtocol.OLLAMA: - url = f"{conn.base_url.rstrip('/')}/api/tags" - async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: - response = await client.get(url, headers=_auth_headers(conn)) - response.raise_for_status() - models = response.json().get("models", []) - results = [ - { - "model_id": item.get("model") or item.get("name"), - "display_name": item.get("name") or item.get("model"), - "source": ModelSource.DISCOVERED, - "capabilities": derive_capabilities(conn, item.get("model") or item.get("name"), item), - "metadata": item, - } - for item in models - if item.get("model") or item.get("name") - ] - elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: - results = await _discover_openai_shaped_models(conn, conn.base_url) - elif conn.protocol == ConnectionProtocol.ANTHROPIC: + if spec.discovery == "ollama": + results = await _ollama_tags_then_show(conn) + elif spec.discovery == "openrouter": + results = await _openrouter_models(conn) + elif spec.discovery == "anthropic_models": results = await _discover_anthropic_models(conn) + elif spec.discovery == "openai_models": + results = await _discover_openai_shaped_models(conn, conn.base_url) + elif spec.discovery == "static": + results = _litellm_static_models(conn) else: results = [] @@ -246,10 +326,7 @@ async def test_model(conn: Connection, model: Model) -> VerifyResult: except Exception as exc: return VerifyResult("UNREACHABLE", False, str(exc)) - model.capabilities_verified = { - **(model.capabilities_verified or {}), - "chat": True, - } + model.supports_chat = True return VerifyResult("OK", True, "Model test succeeded.") diff --git a/surfsense_backend/app/services/model_list_service.py b/surfsense_backend/app/services/model_list_service.py index 1ef0b0c90..0761d7e4f 100644 --- a/surfsense_backend/app/services/model_list_service.py +++ b/surfsense_backend/app/services/model_list_service.py @@ -12,6 +12,8 @@ from pathlib import Path import httpx +from app.services.openrouter_model_normalizer import normalize_openrouter_models + logger = logging.getLogger(__name__) OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models" @@ -121,26 +123,13 @@ def _process_models(raw_models: list[dict]) -> list[dict]: """ processed: list[dict] = [] - for model in raw_models: - model_id: str = model.get("id", "") - name: str = model.get("name", "") - context_length = model.get("context_length") - + for normalized in normalize_openrouter_models(raw_models): + model_id: str = normalized["model_id"] + name: str = normalized.get("display_name") or model_id + context_length = normalized.get("max_input_tokens") if "/" not in model_id: continue - if not _is_text_output_model(model): - continue - - if not _supports_tool_calling(model): - continue - - if not _has_sufficient_context(model): - continue - - if not _is_allowed_model(model): - continue - provider_slug, model_name = model_id.split("/", 1) context_window = _format_context_length(context_length) diff --git a/surfsense_backend/app/services/model_resolver.py b/surfsense_backend/app/services/model_resolver.py index ffa77a9a2..ae6fd2877 100644 --- a/surfsense_backend/app/services/model_resolver.py +++ b/surfsense_backend/app/services/model_resolver.py @@ -12,9 +12,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from app.db import Connection -PROTOCOL_OLLAMA = "OLLAMA" -PROTOCOL_OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" -PROTOCOL_ANTHROPIC = "ANTHROPIC" +from app.services.provider_registry import Transport, spec_for def ensure_v1(base_url: str | None) -> str | None: @@ -32,47 +30,25 @@ def _conn_value(conn: Connection | Mapping[str, Any], key: str) -> Any: return getattr(conn, key) -def _protocol_value(protocol: Any) -> str: - return getattr(protocol, "value", str(protocol)) - - -def default_litellm_provider(protocol: Any) -> str: - protocol_value = _protocol_value(protocol) - defaults = { - PROTOCOL_OLLAMA: "ollama_chat", - PROTOCOL_ANTHROPIC: "anthropic", - PROTOCOL_OPENAI_COMPATIBLE: "openai", - } - return defaults.get(protocol_value, "openai") - - -def _execution_api_base(protocol: str, base_url: str | None) -> str | None: - del protocol - if not base_url: - return None - return base_url.rstrip("/") - - def to_litellm( conn: Connection | Mapping[str, Any], model_id: str, ) -> tuple[str, dict[str, Any]]: """Return ``(model_string, litellm_kwargs)`` for any model role.""" - protocol = _protocol_value(_conn_value(conn, "protocol")) + provider = _conn_value(conn, "provider") base_url = _conn_value(conn, "base_url") api_key = _conn_value(conn, "api_key") - litellm_provider = ( - _conn_value(conn, "litellm_provider") or default_litellm_provider(protocol) - ) extra = _conn_value(conn, "extra") or {} + spec = spec_for(provider) kwargs: dict[str, Any] = {} if api_key: kwargs["api_key"] = api_key - model_string = f"{litellm_provider}/{model_id}" if litellm_provider else model_id - api_base = _execution_api_base(protocol, base_url) - if api_base: + prefix = spec.litellm_prefix or str(provider) + model_string = f"{prefix}/{model_id}" if prefix else model_id + if base_url: + api_base = ensure_v1(base_url) if spec.transport == Transport.OPENAI_COMPATIBLE else base_url.rstrip("/") kwargs["api_base"] = api_base if api_version := extra.get("api_version"): @@ -84,11 +60,11 @@ def to_litellm( def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: """Build an in-memory connection mapping from a global config.""" - protocol = str(config.get("protocol") or PROTOCOL_OPENAI_COMPATIBLE) - litellm_provider = str( - config.get("litellm_provider") + provider = str( + config.get("provider") + or config.get("litellm_provider") or config.get("custom_provider") - or default_litellm_provider(protocol) + or "openai" ) extra: dict[str, Any] = { "litellm_params": config.get("litellm_params") or {}, @@ -96,8 +72,7 @@ def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: if config.get("api_version"): extra["api_version"] = config.get("api_version") return { - "protocol": protocol, - "litellm_provider": litellm_provider, + "provider": provider, "base_url": config.get("api_base") or None, "api_key": config.get("api_key") or None, "extra": extra, @@ -105,7 +80,6 @@ def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: __all__ = [ - "default_litellm_provider", "ensure_v1", "native_connection_from_config", "to_litellm", diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 6996f0fde..fbb70eb5a 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -29,6 +29,13 @@ from app.services.quality_score import ( aggregate_health, static_score_or, ) +from app.services.openrouter_model_normalizer import ( + is_allowed_model as _shared_is_allowed_model, + is_compatible_provider as _shared_is_compatible_provider, + is_openrouter_image_model, + normalize_openrouter_models, + supports_image_input, +) logger = logging.getLogger(__name__) @@ -292,24 +299,16 @@ def _generate_configs( use_default: bool = settings.get("use_default_system_instructions", True) citations_enabled: bool = settings.get("citations_enabled", True) - text_models = [ - m - for m in raw_models - if _is_text_output_model(m) - and _supports_tool_calling(m) - and _has_sufficient_context(m) - and _is_compatible_provider(m) - and _is_allowed_model(m) - and "/" in m.get("id", "") - ] + text_models = normalize_openrouter_models(raw_models) configs: list[dict] = [] taken: set[int] = set() now_ts = int(time.time()) - for model in text_models: - model_id: str = model["id"] - name: str = model.get("name", model_id) + for normalized in text_models: + model = normalized.get("metadata") or {} + model_id: str = normalized["model_id"] + name: str = normalized.get("display_name") or model_id tier = _openrouter_tier(model) static_q = static_score_or(model, now_ts=now_ts) @@ -323,7 +322,7 @@ def _generate_configs( "seo_enabled": seo_enabled, "seo_slug": None, "quota_reserve_tokens": quota_reserve_tokens, - "litellm_provider": "openrouter", + "provider": "openrouter", "model_name": model_id, "api_key": api_key, "api_base": "https://openrouter.ai/api/v1", @@ -345,7 +344,7 @@ def _generate_configs( # ``stream_new_chat`` as a fail-fast safety net before the # OpenRouter request would otherwise 404 with # ``"No endpoints found that support image input"``. - "supports_image_input": _supports_image_input(model), + "supports_image_input": bool(normalized.get("supports_image_input")), _OPENROUTER_DYNAMIC_MARKER: True, # Auto (Fastest) ranking metadata. ``quality_score`` is initialised # to the static score and gets re-blended with health on the next @@ -403,10 +402,7 @@ def _generate_image_gen_configs( image_models = [ m for m in raw_models - if _is_image_output_model(m) - and _is_compatible_provider(m) - and _is_allowed_model(m) - and "/" in m.get("id", "") + if is_openrouter_image_model(m) ] configs: list[dict] = [] @@ -420,7 +416,7 @@ def _generate_image_gen_configs( "id": _stable_config_id(model_id, id_offset, taken), "name": name, "description": f"{name} via OpenRouter (image generation)", - "litellm_provider": "openrouter", + "provider": "openrouter", "model_name": model_id, "api_key": api_key, "api_base": "https://openrouter.ai/api/v1", @@ -468,9 +464,9 @@ def _generate_vision_llm_configs( vision_models = [ m for m in raw_models - if _is_vision_input_model(m) - and _is_compatible_provider(m) - and _is_allowed_model(m) + if supports_image_input(m) + and _shared_is_compatible_provider(m) + and _shared_is_allowed_model(m) and "/" in m.get("id", "") ] @@ -499,7 +495,7 @@ def _generate_vision_llm_configs( "id": _stable_config_id(model_id, id_offset, taken), "name": name, "description": f"{name} via OpenRouter (vision)", - "litellm_provider": "openrouter", + "provider": "openrouter", "model_name": model_id, "api_key": api_key, "api_base": "https://openrouter.ai/api/v1", @@ -544,11 +540,9 @@ class OpenRouterIntegrationService: # Cached raw catalogue from the most recent fetch. Image / vision # emitters reuse this to avoid a second network call per surface. self._raw_models: list[dict] = [] - # Image / vision config caches (only populated when the matching - # opt-in flag is true on initialize). Refreshed in lockstep with - # the chat catalogue. + # Image config cache (only populated when the matching opt-in flag is + # true on initialize). Refreshed in lockstep with the chat catalogue. self._image_configs: list[dict] = [] - self._vision_configs: list[dict] = [] @classmethod def get_instance(cls) -> "OpenRouterIntegrationService": @@ -583,7 +577,7 @@ class OpenRouterIntegrationService: self._configs_by_id = {c["id"]: c for c in self._configs} self._raw_pricing = _extract_raw_pricing(raw_models) - # Populate image / vision caches when their opt-in flag is set. + # Populate image cache when its opt-in flag is set. # Empty otherwise so the accessors return [] without re-running # filters every refresh. if settings.get("image_generation_enabled"): @@ -595,15 +589,6 @@ class OpenRouterIntegrationService: else: self._image_configs = [] - if settings.get("vision_enabled"): - self._vision_configs = _generate_vision_llm_configs(raw_models, settings) - logger.info( - "OpenRouter integration: vision LLM emission ON (%d models)", - len(self._vision_configs), - ) - else: - self._vision_configs = [] - self._initialized = True tier_counts = self._tier_counts(self._configs) @@ -657,9 +642,9 @@ class OpenRouterIntegrationService: self._configs = new_configs self._configs_by_id = new_by_id - # Image / vision lists are atomic-swapped the same way: filter out + # Image list is atomic-swapped the same way: filter out # the previous dynamic entries from the live config list and append - # the freshly generated ones. No-ops when the opt-in flag is off. + # the freshly generated ones. No-op when the opt-in flag is off. if self._settings.get("image_generation_enabled"): new_image = _generate_image_gen_configs(raw_models, self._settings) static_image = [ @@ -670,16 +655,6 @@ class OpenRouterIntegrationService: app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image self._image_configs = new_image - if self._settings.get("vision_enabled"): - new_vision = _generate_vision_llm_configs(raw_models, self._settings) - static_vision = [ - c - for c in app_config.GLOBAL_VISION_LLM_CONFIGS - if not c.get(_OPENROUTER_DYNAMIC_MARKER) - ] - app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision - self._vision_configs = new_vision - # Catalogue churn invalidates per-config "recently healthy" credit # earned by the previous turn's preflight. Drop the whole table so # the next turn re-probes against the freshly loaded configs. @@ -701,7 +676,7 @@ class OpenRouterIntegrationService: ) # Re-blend health scores against the freshly fetched catalogue. Also - # re-stamps health for any YAML-curated cfg with litellm_provider=openrouter + # re-stamps health for any YAML-curated cfg with provider=openrouter # so a hand-picked dead OR model is gated like a dynamic one. await self._enrich_health_safely(static_configs + new_configs, log_summary=True) @@ -778,7 +753,7 @@ class OpenRouterIntegrationService: the entire previous cycle's cache for this run. """ or_cfgs = [ - c for c in configs if str(c.get("litellm_provider", "")).lower() == "openrouter" + c for c in configs if str(c.get("provider", "")).lower() == "openrouter" ] if not or_cfgs: return @@ -959,17 +934,6 @@ class OpenRouterIntegrationService: """ return list(self._image_configs) - def get_vision_llm_configs(self) -> list[dict]: - """Return the dynamic OpenRouter vision-LLM configs (empty list - when the ``vision_enabled`` flag is off). - - Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token`` - so ``pricing_registration`` can teach LiteLLM the cost of these - models the same way it does for chat — which keeps the billable - wrapper able to debit accurate micro-USD on a vision call. - """ - return list(self._vision_configs) - def get_raw_pricing(self) -> dict[str, dict[str, str]]: """Return the cached raw OpenRouter pricing map. diff --git a/surfsense_backend/app/services/openrouter_model_normalizer.py b/surfsense_backend/app/services/openrouter_model_normalizer.py new file mode 100644 index 000000000..0b646f933 --- /dev/null +++ b/surfsense_backend/app/services/openrouter_model_normalizer.py @@ -0,0 +1,121 @@ +"""Shared OpenRouter model normalization. + +OpenRouter metadata is richer than generic OpenAI-compatible ``/models`` +responses. Keep all OpenRouter filtering and capability extraction here so +GLOBAL catalogue generation and BYOK discovery agree. +""" + +from __future__ import annotations + +from typing import Any + +from app.db import ModelSource + +MIN_CONTEXT_LENGTH = 100_000 + +EXCLUDED_PROVIDER_SLUGS = {"amazon"} +EXCLUDED_MODEL_IDS: set[str] = { + "openai/gpt-4-1106-preview", + "openai/gpt-4-turbo-preview", + "openai/gpt-4o:extended", + "arcee-ai/virtuoso-large", + "openai/o3-deep-research", + "openai/o4-mini-deep-research", + "openrouter/free", +} +EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",) + + +def is_text_output_model(model: dict[str, Any]) -> bool: + output_mods = model.get("architecture", {}).get("output_modalities", []) + return output_mods == ["text"] + + +def is_image_output_model(model: dict[str, Any]) -> bool: + output_mods = model.get("architecture", {}).get("output_modalities", []) or [] + return "image" in output_mods + + +def supports_image_input(model: dict[str, Any]) -> bool: + input_mods = model.get("architecture", {}).get("input_modalities", []) or [] + return "image" in input_mods + + +def supports_tool_calling(model: dict[str, Any]) -> bool: + supported = model.get("supported_parameters") or [] + return "tools" in supported + + +def has_sufficient_context(model: dict[str, Any]) -> bool: + return int(model.get("context_length") or 0) >= MIN_CONTEXT_LENGTH + + +def is_compatible_provider(model: dict[str, Any]) -> bool: + model_id = str(model.get("id") or "") + slug = model_id.split("/", 1)[0] if "/" in model_id else "" + return slug not in EXCLUDED_PROVIDER_SLUGS + + +def is_allowed_model(model: dict[str, Any]) -> bool: + model_id = str(model.get("id") or "") + if model_id in EXCLUDED_MODEL_IDS: + return False + base_id = model_id.split(":")[0] + return not base_id.endswith(EXCLUDED_MODEL_SUFFIXES) + + +def is_openrouter_chat_model(model: dict[str, Any]) -> bool: + return ( + "/" in str(model.get("id") or "") + and is_text_output_model(model) + and supports_tool_calling(model) + and has_sufficient_context(model) + and is_compatible_provider(model) + and is_allowed_model(model) + ) + + +def is_openrouter_image_model(model: dict[str, Any]) -> bool: + return ( + "/" in str(model.get("id") or "") + and is_image_output_model(model) + and is_compatible_provider(model) + and is_allowed_model(model) + ) + + +def normalize_openrouter_models(raw_models: list[dict[str, Any]]) -> list[dict[str, Any]]: + normalized: list[dict[str, Any]] = [] + for model in raw_models: + if not is_openrouter_chat_model(model): + continue + model_id = str(model.get("id") or "") + normalized.append( + { + "model_id": model_id, + "display_name": model.get("name") or model_id, + "source": ModelSource.DISCOVERED, + "supports_chat": True, + "max_input_tokens": model.get("context_length"), + "supports_image_input": supports_image_input(model), + "supports_tools": supports_tool_calling(model), + "supports_image_generation": False, + "metadata": model, + } + ) + return normalized + + +__all__ = [ + "MIN_CONTEXT_LENGTH", + "has_sufficient_context", + "is_allowed_model", + "is_compatible_provider", + "is_image_output_model", + "is_openrouter_chat_model", + "is_openrouter_image_model", + "is_text_output_model", + "normalize_openrouter_models", + "supports_image_input", + "supports_tool_calling", +] diff --git a/surfsense_backend/app/services/pricing_registration.py b/surfsense_backend/app/services/pricing_registration.py index 6b99fe723..9e4e3b552 100644 --- a/surfsense_backend/app/services/pricing_registration.py +++ b/surfsense_backend/app/services/pricing_registration.py @@ -143,7 +143,7 @@ def _register_chat_shape_configs( sample_keys: list[str] = [] for cfg in configs: - provider = str(cfg.get("litellm_provider") or "").lower() + provider = str(cfg.get("provider") or cfg.get("litellm_provider") or "").lower() model_name = str(cfg.get("model_name") or "").strip() litellm_params = cfg.get("litellm_params") or {} base_model = str(litellm_params.get("base_model") or model_name).strip() @@ -216,9 +216,8 @@ def _register_chat_shape_configs( def register_pricing_from_global_configs() -> None: """Register pricing for every known LLM deployment with LiteLLM. - Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS`` - so vision calls (during indexing) can resolve cost the same way chat - calls do — namely: + Walks ``config.GLOBAL_LLM_CONFIGS`` so chat and vision calls can resolve + cost from the same chat-shaped deployment configs: 1. ``OPENROUTER``: pulls the cached raw pricing from ``OpenRouterIntegrationService`` (populated during its own @@ -245,10 +244,7 @@ def register_pricing_from_global_configs() -> None: from app.config import config as app_config chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or []) - vision_configs: list[dict] = list( - getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or [] - ) - if not chat_configs and not vision_configs: + if not chat_configs: logger.info("[PricingRegistration] no global configs to register") return @@ -267,7 +263,3 @@ def register_pricing_from_global_configs() -> None: if chat_configs: _register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat") - if vision_configs: - _register_chat_shape_configs( - vision_configs, or_pricing=or_pricing, label="vision" - ) diff --git a/surfsense_backend/app/services/provider_capabilities.py b/surfsense_backend/app/services/provider_capabilities.py index 9521ef7a4..fae283ab6 100644 --- a/surfsense_backend/app/services/provider_capabilities.py +++ b/surfsense_backend/app/services/provider_capabilities.py @@ -51,7 +51,7 @@ logger = logging.getLogger(__name__) def _candidate_model_strings( *, - litellm_provider: str | None, + provider: str | None, model_name: str | None, base_model: str | None, custom_provider: str | None, @@ -78,7 +78,7 @@ def _candidate_model_strings( seen.add(key) candidates.append(key) - provider_prefix = custom_provider or litellm_provider + provider_prefix = custom_provider or provider primary_model = base_model or model_name bare_model = model_name @@ -113,7 +113,7 @@ def _candidate_model_strings( def derive_supports_image_input( *, - litellm_provider: str | None = None, + provider: str | None = None, model_name: str | None = None, base_model: str | None = None, custom_provider: str | None = None, @@ -147,7 +147,7 @@ def derive_supports_image_input( return False for model_string, custom_llm_provider in _candidate_model_strings( - litellm_provider=litellm_provider, + provider=provider, model_name=model_name, base_model=base_model, custom_provider=custom_provider, @@ -172,7 +172,7 @@ def derive_supports_image_input( def is_known_text_only_chat_model( *, - litellm_provider: str | None = None, + provider: str | None = None, model_name: str | None = None, base_model: str | None = None, custom_provider: str | None = None, @@ -193,7 +193,7 @@ def is_known_text_only_chat_model( leads to the regression we're fixing here. """ for model_string, custom_llm_provider in _candidate_model_strings( - litellm_provider=litellm_provider, + provider=provider, model_name=model_name, base_model=base_model, custom_provider=custom_provider, diff --git a/surfsense_backend/app/services/provider_registry.py b/surfsense_backend/app/services/provider_registry.py new file mode 100644 index 000000000..871769f11 --- /dev/null +++ b/surfsense_backend/app/services/provider_registry.py @@ -0,0 +1,98 @@ +"""Provider registry for model connections. + +The provider string is the single public identity axis. This registry only +describes providers whose behavior differs from LiteLLM's native default. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import StrEnum +from typing import Literal + + +class Transport(StrEnum): + NATIVE = "NATIVE" + OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" + OLLAMA = "OLLAMA" + + +DiscoveryKind = Literal[ + "ollama", + "openai_models", + "anthropic_models", + "openrouter", + "static", + "none", +] + +AuthStyle = Literal["bearer", "x-api-key", "none", "native"] + + +@dataclass(frozen=True) +class ProviderSpec: + transport: Transport + litellm_prefix: str | None + discovery: DiscoveryKind + default_base_url: str | None + base_url_required: bool + auth_style: AuthStyle + + +REGISTRY: dict[str, ProviderSpec] = { + "openai": ProviderSpec( + Transport.NATIVE, "openai", "openai_models", None, False, "bearer" + ), + "anthropic": ProviderSpec( + Transport.NATIVE, "anthropic", "anthropic_models", None, False, "x-api-key" + ), + "azure": ProviderSpec(Transport.NATIVE, "azure", "static", None, True, "native"), + "vertex_ai": ProviderSpec( + Transport.NATIVE, "vertex_ai", "static", None, False, "native" + ), + "bedrock": ProviderSpec( + Transport.NATIVE, "bedrock", "static", None, False, "native" + ), + "openrouter": ProviderSpec( + Transport.OPENAI_COMPATIBLE, + "openrouter", + "openrouter", + "https://openrouter.ai/api/v1", + False, + "bearer", + ), + "openai_compatible": ProviderSpec( + Transport.OPENAI_COMPATIBLE, + "openai", + "openai_models", + None, + True, + "bearer", + ), + "lm_studio": ProviderSpec( + Transport.OPENAI_COMPATIBLE, + "openai", + "openai_models", + "http://localhost:1234/v1", + True, + "bearer", + ), + "ollama_chat": ProviderSpec( + Transport.OLLAMA, + "ollama_chat", + "ollama", + "http://localhost:11434", + True, + "none", + ), +} + + +def spec_for(provider: str | None) -> ProviderSpec: + provider_key = (provider or "").strip() + return REGISTRY.get(provider_key) or ProviderSpec( + Transport.NATIVE, provider_key or "openai", "static", None, False, "native" + ) + + +__all__ = ["REGISTRY", "ProviderSpec", "Transport", "spec_for"] diff --git a/surfsense_backend/app/services/quality_score.py b/surfsense_backend/app/services/quality_score.py index 95484439b..9cc9c21ac 100644 --- a/surfsense_backend/app/services/quality_score.py +++ b/surfsense_backend/app/services/quality_score.py @@ -273,7 +273,7 @@ def static_score_yaml(cfg: dict) -> int: listed this model. Pricing / context fall through to lazy ``litellm`` lookups; failures are silent (we just lose those sub-points). """ - provider = str(cfg.get("litellm_provider", "")).lower() + provider = str(cfg.get("provider") or cfg.get("litellm_provider") or "").lower() base = PROVIDER_PRESTIGE_YAML.get(provider, 15) model_name = cfg.get("model_name") or "" diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py index f6fcf75d7..69b9f4ab8 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py @@ -40,7 +40,7 @@ def check_image_input_capability( else None ) if not is_known_text_only_chat_model( - litellm_provider=agent_config.provider, + provider=agent_config.provider, model_name=agent_config.model_name, base_model=agent_base_model, custom_provider=agent_config.custom_provider, diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py index f6870f5fa..cfd50950e 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py @@ -22,6 +22,7 @@ from app.agents.chat.runtime.llm_config import ( ) from app.config import config from app.db import Model, SearchSpace +from app.services.model_capabilities import has_capability from app.services.model_resolver import to_litellm @@ -96,7 +97,7 @@ async def load_llm_bundle( model_id=config_id, search_space=search_space, ) - if not model or not (model.capabilities or {}).get("chat"): + if not model or not has_capability(model, "chat"): return ( None, None, @@ -106,12 +107,12 @@ async def load_llm_bundle( agent_config = _agent_config_from_resolved( config_id=config_id, config_name=model.display_name or model.model_id, - provider=model.connection.litellm_provider or "", + provider=model.connection.provider or "", model_name=model.model_id, api_key=model.connection.api_key, api_base=model.connection.base_url, litellm_params=(model.connection.extra or {}).get("litellm_params"), - supports_image_input=bool((model.capabilities or {}).get("vision")), + supports_image_input=has_capability(model, "vision"), billing_tier="free", ) return ( @@ -121,7 +122,7 @@ async def load_llm_bundle( ) global_model = next((m for m in config.GLOBAL_MODELS if m.get("id") == config_id), None) - if not global_model or not (global_model.get("capabilities") or {}).get("chat"): + if not global_model or not has_capability(global_model, "chat"): return None, None, f"Failed to load global chat model with id {config_id}" global_connection = next( ( @@ -137,12 +138,12 @@ async def load_llm_bundle( agent_config = _agent_config_from_resolved( config_id=config_id, config_name=global_model.get("display_name") or global_model.get("model_id"), - provider=global_connection.get("litellm_provider") or "", + provider=global_connection.get("provider") or "", model_name=global_model["model_id"], api_key=global_connection.get("api_key"), api_base=global_connection.get("base_url"), litellm_params=(global_connection.get("extra") or {}).get("litellm_params"), - supports_image_input=bool((global_model.get("capabilities") or {}).get("vision")), + supports_image_input=has_capability(global_model, "vision"), billing_tier=str(global_model.get("billing_tier", "free")).lower(), ) return (