mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-12 20:45:20 +02:00
feat(model-connections): enhance model discovery with OpenAI and LiteLLM support
This commit is contained in:
parent
50c816c81c
commit
3f01642199
2 changed files with 80 additions and 21 deletions
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
|
@ -12,7 +13,8 @@ import httpx
|
|||
import litellm
|
||||
|
||||
from app.db import Connection, ConnectionProtocol, Model, ModelSource
|
||||
from app.services.model_resolver import ensure_v1, to_litellm
|
||||
from app.services.model_resolver import NATIVE_PROVIDER_PREFIX, ensure_v1, to_litellm
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -133,6 +135,63 @@ def _allowlist(conn: Connection) -> set[str]:
|
|||
return {str(item).strip() for item in raw if str(item).strip()}
|
||||
|
||||
|
||||
async def _discover_openai_shaped_models(conn: Connection, base_url: str | None) -> list[dict[str, Any]]:
|
||||
if not base_url:
|
||||
return []
|
||||
|
||||
url = f"{ensure_v1(base_url)}/models"
|
||||
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
|
||||
response = await client.get(url, headers=_auth_headers(conn))
|
||||
response.raise_for_status()
|
||||
return [
|
||||
{
|
||||
"model_id": item.get("id"),
|
||||
"display_name": item.get("name") or item.get("id"),
|
||||
"source": ModelSource.DISCOVERED,
|
||||
"capabilities": derive_capabilities(conn, item.get("id"), item),
|
||||
"metadata": item,
|
||||
}
|
||||
for item in response.json().get("data", [])
|
||||
if item.get("id")
|
||||
]
|
||||
|
||||
|
||||
def _litellm_valid_model_ids(provider: str, api_key: str | None) -> list[str]:
|
||||
if not api_key:
|
||||
return []
|
||||
|
||||
try:
|
||||
models = litellm.get_valid_models(
|
||||
check_provider_endpoint=True,
|
||||
custom_llm_provider=provider,
|
||||
api_key=api_key,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("LiteLLM model discovery failed for provider %s: %s", provider, exc)
|
||||
return []
|
||||
|
||||
provider_prefix = f"{provider}/"
|
||||
return [
|
||||
model.removeprefix(provider_prefix)
|
||||
for model in models
|
||||
if isinstance(model, str) and model.strip()
|
||||
]
|
||||
|
||||
|
||||
async def _discover_litellm_native_models(conn: Connection, provider: str) -> list[dict[str, Any]]:
|
||||
model_ids = await asyncio.to_thread(_litellm_valid_model_ids, provider, conn.api_key)
|
||||
return [
|
||||
{
|
||||
"model_id": model_id,
|
||||
"display_name": model_id,
|
||||
"source": ModelSource.DISCOVERED,
|
||||
"capabilities": derive_capabilities(conn, model_id),
|
||||
"metadata": {},
|
||||
}
|
||||
for model_id in model_ids
|
||||
]
|
||||
|
||||
|
||||
def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, bool]:
|
||||
metadata = metadata or {}
|
||||
if conn.protocol == ConnectionProtocol.OLLAMA:
|
||||
|
|
@ -171,25 +230,21 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]:
|
|||
if item.get("model") or item.get("name")
|
||||
]
|
||||
elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE:
|
||||
url = f"{ensure_v1(conn.base_url)}/models"
|
||||
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
|
||||
response = await client.get(url, headers=_auth_headers(conn))
|
||||
response.raise_for_status()
|
||||
models = response.json().get("data", [])
|
||||
results = [
|
||||
{
|
||||
"model_id": item.get("id"),
|
||||
"display_name": item.get("name") or item.get("id"),
|
||||
"source": ModelSource.DISCOVERED,
|
||||
"capabilities": derive_capabilities(conn, item.get("id"), item),
|
||||
"metadata": item,
|
||||
}
|
||||
for item in models
|
||||
if item.get("id")
|
||||
]
|
||||
results = await _discover_openai_shaped_models(conn, conn.base_url)
|
||||
else:
|
||||
# Native providers rely on curated/global catalog entries or manual rows.
|
||||
return []
|
||||
provider_key = (conn.native_provider or "").upper()
|
||||
provider = NATIVE_PROVIDER_PREFIX.get(provider_key, provider_key.lower())
|
||||
api_base = resolve_api_base(
|
||||
provider=provider_key,
|
||||
provider_prefix=provider,
|
||||
config_api_base=conn.base_url,
|
||||
)
|
||||
if api_base:
|
||||
results = await _discover_openai_shaped_models(conn, api_base)
|
||||
elif provider:
|
||||
results = await _discover_litellm_native_models(conn, provider)
|
||||
else:
|
||||
results = []
|
||||
|
||||
if allowlist:
|
||||
results = [item for item in results if item["model_id"] in allowlist]
|
||||
|
|
|
|||
|
|
@ -90,8 +90,12 @@ export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => {
|
|||
return {
|
||||
mutationKey: ["model-connections", "discover"],
|
||||
mutationFn: (id: number) => modelConnectionsApiService.discoverModels(id),
|
||||
onSuccess: () => {
|
||||
toast.success("Models discovered");
|
||||
onSuccess: (models) => {
|
||||
toast.success(
|
||||
models.length
|
||||
? `${models.length} models discovered`
|
||||
: "No models found for this connection"
|
||||
);
|
||||
invalidateModelConnections(searchSpaceId);
|
||||
},
|
||||
onError: (error: Error) => toast.error(error.message || "Failed to discover models"),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue