feat(model-connections): enhance model discovery with OpenAI and LiteLLM support

This commit is contained in:
Anish Sarkar 2026-06-11 17:29:55 +05:30
parent 50c816c81c
commit 3f01642199
2 changed files with 80 additions and 21 deletions

View file

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
import contextlib
import logging
from dataclasses import dataclass
@ -12,7 +13,8 @@ import httpx
import litellm
from app.db import Connection, ConnectionProtocol, Model, ModelSource
from app.services.model_resolver import ensure_v1, to_litellm
from app.services.model_resolver import NATIVE_PROVIDER_PREFIX, ensure_v1, to_litellm
from app.services.provider_api_base import resolve_api_base
logger = logging.getLogger(__name__)
@ -133,6 +135,63 @@ def _allowlist(conn: Connection) -> set[str]:
return {str(item).strip() for item in raw if str(item).strip()}
async def _discover_openai_shaped_models(conn: Connection, base_url: str | None) -> list[dict[str, Any]]:
    """List models from an OpenAI-style ``GET <base_url>/models`` endpoint.

    Args:
        conn: Connection used to build the request headers (via
            ``_auth_headers``) and passed through to ``derive_capabilities``.
        base_url: Root URL of the provider API. ``ensure_v1`` is applied to it
            before ``/models`` is appended. A falsy value short-circuits to
            an empty result without any network call.

    Returns:
        One dict per listed model with the keys ``model_id``,
        ``display_name``, ``source``, ``capabilities`` and ``metadata``.
        Entries that are not JSON objects or that lack an ``id`` are skipped.

    Raises:
        httpx.HTTPStatusError: If the endpoint answers with an error status
            (raised by ``response.raise_for_status()``).
    """
    if not base_url:
        return []

    url = f"{ensure_v1(base_url)}/models"
    async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
        response = await client.get(url, headers=_auth_headers(conn))
        response.raise_for_status()

    payload = response.json()
    # The usual shape is {"data": [...]}. Also tolerate a bare JSON array and
    # a null/absent/non-list "data" field instead of crashing on `.get` or on
    # iterating None.
    entries = payload.get("data") if isinstance(payload, dict) else payload
    if not isinstance(entries, list):
        entries = []

    return [
        {
            "model_id": item.get("id"),
            "display_name": item.get("name") or item.get("id"),
            "source": ModelSource.DISCOVERED,
            "capabilities": derive_capabilities(conn, item.get("id"), item),
            "metadata": item,
        }
        for item in entries
        # Guard against malformed entries (strings, nulls) in the listing.
        if isinstance(item, dict) and item.get("id")
    ]
def _litellm_valid_model_ids(provider: str, api_key: str | None) -> list[str]:
if not api_key:
return []
try:
models = litellm.get_valid_models(
check_provider_endpoint=True,
custom_llm_provider=provider,
api_key=api_key,
)
except Exception as exc:
logger.warning("LiteLLM model discovery failed for provider %s: %s", provider, exc)
return []
provider_prefix = f"{provider}/"
return [
model.removeprefix(provider_prefix)
for model in models
if isinstance(model, str) and model.strip()
]
async def _discover_litellm_native_models(conn: Connection, provider: str) -> list[dict[str, Any]]:
    """Build discovery rows for a native provider from LiteLLM's model list.

    Args:
        conn: Connection whose ``api_key`` is used for the lookup and which is
            passed to ``derive_capabilities`` for each model.
        provider: LiteLLM provider name to query.

    Returns:
        One dict per model ID with ``model_id``, ``display_name`` (same as the
        ID), ``source``, ``capabilities`` and an empty ``metadata`` dict.
    """
    # The lookup helper is synchronous, so run it in a worker thread.
    discovered_ids = await asyncio.to_thread(_litellm_valid_model_ids, provider, conn.api_key)

    rows: list[dict[str, Any]] = []
    for discovered_id in discovered_ids:
        row = {
            "model_id": discovered_id,
            "display_name": discovered_id,
            "source": ModelSource.DISCOVERED,
            "capabilities": derive_capabilities(conn, discovered_id),
            "metadata": {},
        }
        rows.append(row)
    return rows
def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, bool]:
metadata = metadata or {}
if conn.protocol == ConnectionProtocol.OLLAMA:
@ -171,25 +230,21 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]:
if item.get("model") or item.get("name")
]
elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE:
url = f"{ensure_v1(conn.base_url)}/models"
async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client:
response = await client.get(url, headers=_auth_headers(conn))
response.raise_for_status()
models = response.json().get("data", [])
results = [
{
"model_id": item.get("id"),
"display_name": item.get("name") or item.get("id"),
"source": ModelSource.DISCOVERED,
"capabilities": derive_capabilities(conn, item.get("id"), item),
"metadata": item,
}
for item in models
if item.get("id")
]
results = await _discover_openai_shaped_models(conn, conn.base_url)
else:
# Native providers rely on curated/global catalog entries or manual rows.
return []
provider_key = (conn.native_provider or "").upper()
provider = NATIVE_PROVIDER_PREFIX.get(provider_key, provider_key.lower())
api_base = resolve_api_base(
provider=provider_key,
provider_prefix=provider,
config_api_base=conn.base_url,
)
if api_base:
results = await _discover_openai_shaped_models(conn, api_base)
elif provider:
results = await _discover_litellm_native_models(conn, provider)
else:
results = []
if allowlist:
results = [item for item in results if item["model_id"] in allowlist]

View file

@ -90,8 +90,12 @@ export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => {
return {
mutationKey: ["model-connections", "discover"],
mutationFn: (id: number) => modelConnectionsApiService.discoverModels(id),
onSuccess: () => {
toast.success("Models discovered");
onSuccess: (models) => {
toast.success(
models.length
? `${models.length} models discovered`
: "No models found for this connection"
);
invalidateModelConnections(searchSpaceId);
},
onError: (error: Error) => toast.error(error.message || "Failed to discover models"),