From 3f016421992919a63211df2adbdc13df39278fbc Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 17:29:55 +0530 Subject: [PATCH] feat(model-connections): enhance model discovery with OpenAI and LiteLLM support --- .../app/services/model_connection_service.py | 93 +++++++++++++++---- .../model-connections-mutation.atoms.ts | 8 +- 2 files changed, 80 insertions(+), 21 deletions(-) diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index c8d2e8a5a..42a4792a4 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import contextlib import logging from dataclasses import dataclass @@ -12,7 +13,8 @@ import httpx import litellm from app.db import Connection, ConnectionProtocol, Model, ModelSource -from app.services.model_resolver import ensure_v1, to_litellm +from app.services.model_resolver import NATIVE_PROVIDER_PREFIX, ensure_v1, to_litellm +from app.services.provider_api_base import resolve_api_base logger = logging.getLogger(__name__) @@ -133,6 +135,63 @@ def _allowlist(conn: Connection) -> set[str]: return {str(item).strip() for item in raw if str(item).strip()} +async def _discover_openai_shaped_models(conn: Connection, base_url: str | None) -> list[dict[str, Any]]: + if not base_url: + return [] + + url = f"{ensure_v1(base_url)}/models" + async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: + response = await client.get(url, headers=_auth_headers(conn)) + response.raise_for_status() + return [ + { + "model_id": item.get("id"), + "display_name": item.get("name") or item.get("id"), + "source": ModelSource.DISCOVERED, + "capabilities": derive_capabilities(conn, item.get("id"), item), + "metadata": item, + } + for item in response.json().get("data", []) + if item.get("id") + ] + + +def _litellm_valid_model_ids(provider: str, api_key: str | None) -> list[str]: + if not api_key: + return [] + + try: + models = litellm.get_valid_models( + check_provider_endpoint=True, + custom_llm_provider=provider, + api_key=api_key, + ) + except Exception as exc: + logger.warning("LiteLLM model discovery failed for provider %s: %s", provider, exc) + return [] + + provider_prefix = f"{provider}/" + return [ + model.removeprefix(provider_prefix) + for model in models + if isinstance(model, str) and model.strip() + ] + + +async def _discover_litellm_native_models(conn: Connection, provider: str) -> list[dict[str, Any]]: + model_ids = await asyncio.to_thread(_litellm_valid_model_ids, provider, conn.api_key) + return [ + { + "model_id": model_id, + "display_name": model_id, + "source": ModelSource.DISCOVERED, + "capabilities": derive_capabilities(conn, model_id), + "metadata": {}, + } + for model_id in model_ids + ] + + def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, bool]: metadata = metadata or {} if conn.protocol == ConnectionProtocol.OLLAMA: @@ -171,25 +230,21 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]: if item.get("model") or item.get("name") ] elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: - url = f"{ensure_v1(conn.base_url)}/models" - async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: - response = await client.get(url, headers=_auth_headers(conn)) - response.raise_for_status() - models = response.json().get("data", []) - results = [ - { - "model_id": item.get("id"), - "display_name": item.get("name") or item.get("id"), - "source": ModelSource.DISCOVERED, - "capabilities": derive_capabilities(conn, item.get("id"), item), - "metadata": item, - } - for item in models - if item.get("id") - ] + results = await _discover_openai_shaped_models(conn, conn.base_url) else: - # Native providers rely on curated/global catalog entries or manual rows. - return [] + provider_key = (conn.native_provider or "").upper() + provider = NATIVE_PROVIDER_PREFIX.get(provider_key, provider_key.lower()) + api_base = resolve_api_base( + provider=provider_key, + provider_prefix=provider, + config_api_base=conn.base_url, + ) + if api_base: + results = await _discover_openai_shaped_models(conn, api_base) + elif provider: + results = await _discover_litellm_native_models(conn, provider) + else: + results = [] if allowlist: results = [item for item in results if item["model_id"] in allowlist] diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts index 612216bf2..76289e60d 100644 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -90,8 +90,12 @@ export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => { return { mutationKey: ["model-connections", "discover"], mutationFn: (id: number) => modelConnectionsApiService.discoverModels(id), - onSuccess: () => { - toast.success("Models discovered"); + onSuccess: (models) => { + toast.success( + models.length + ? `${models.length} models discovered` + : "No models found for this connection" + ); invalidateModelConnections(searchSpaceId); }, onError: (error: Error) => toast.error(error.message || "Failed to discover models"),