feat(model-connections): enhance model connection functionality with preview and selection features

This commit is contained in:
Anish Sarkar 2026-06-12 22:41:21 +05:30
parent 356f0e56c5
commit 407f2a9612
20 changed files with 630 additions and 429 deletions

View file

@ -21,10 +21,12 @@ from app.schemas import (
ConnectionRead,
ConnectionUpdate,
ModelCreate,
ModelPreviewRead,
ModelProviderRead,
ModelRead,
ModelRolesRead,
ModelRolesUpdate,
ModelSelection,
ModelsBulkUpdate,
ModelUpdate,
VerifyConnectionResponse,
@ -48,6 +50,21 @@ def _model_read(model: Model | dict) -> ModelRead:
return ModelRead.model_validate(model)
def _preview_model_read(item: dict) -> ModelPreviewRead:
return ModelPreviewRead(
model_id=item["model_id"],
display_name=item.get("display_name"),
source=item.get("source", ModelSource.DISCOVERED),
supports_chat=item.get("supports_chat"),
max_input_tokens=item.get("max_input_tokens"),
supports_image_input=item.get("supports_image_input"),
supports_tools=item.get("supports_tools"),
supports_image_generation=item.get("supports_image_generation"),
enabled=item.get("enabled", False),
metadata=item.get("metadata") or item.get("catalog") or {},
)
def _connection_read(conn: Connection | dict, models: list[Model | dict] | None = None) -> ConnectionRead:
if isinstance(conn, dict):
payload = {
@ -86,6 +103,25 @@ def _apply_model_facts(model: Model, facts: dict) -> None:
model.supports_image_generation = facts.get("supports_image_generation")
def _selection_to_model(conn: Connection, selection: ModelSelection) -> Model:
source = (
selection.source
if isinstance(selection.source, ModelSource)
else ModelSource(selection.source)
)
model = Model(
connection_id=conn.id,
model_id=selection.model_id.strip(),
display_name=selection.display_name,
source=source,
capabilities_override={},
enabled=selection.enabled,
catalog=selection.metadata,
)
_apply_model_facts(model, selection.model_dump())
return model
def _default_model_for(models: list[Model], capability: str) -> int | None:
for model in models:
if model.enabled and has_capability(model, capability):
@ -226,7 +262,7 @@ async def create_connection(
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to create model connections in this search space",
)
payload = data.model_dump(exclude={"search_space_id"})
payload = data.model_dump(exclude={"search_space_id", "models"})
conn = Connection(
**payload,
@ -234,9 +270,51 @@ async def create_connection(
user_id=user.id,
)
session.add(conn)
await session.flush()
seen_model_ids: set[str] = set()
for selection in data.models:
model_id = selection.model_id.strip()
if not model_id or model_id in seen_model_ids:
continue
seen_model_ids.add(model_id)
session.add(_selection_to_model(conn, selection))
await session.commit()
await session.refresh(conn)
return _connection_read(conn, [])
conn = await _load_connection(session, conn.id)
await _default_unset_roles(session, conn, list(conn.models))
await session.commit()
conn = await _load_connection(session, conn.id)
return _connection_read(conn, list(conn.models))
@router.post("/model-connections/discover-preview", response_model=list[ModelPreviewRead])
async def preview_connection_models(
data: ConnectionCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None:
await check_permission(
session,
user,
data.search_space_id,
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to create model connections in this search space",
)
draft = Connection(
provider=data.provider,
base_url=data.base_url,
api_key=data.api_key,
extra=data.extra or {},
scope=data.scope,
enabled=data.enabled,
search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None,
user_id=user.id,
)
discovered = await discover_models(draft)
return [_preview_model_read(item) for item in discovered]
@router.put("/model-connections/{connection_id}", response_model=ConnectionRead)

View file

@ -49,10 +49,12 @@ from .model_connections import (
ConnectionRead,
ConnectionUpdate,
ModelCreate,
ModelPreviewRead,
ModelProviderRead,
ModelRead,
ModelRolesRead,
ModelRolesUpdate,
ModelSelection,
ModelsBulkUpdate,
ModelUpdate,
VerifyConnectionResponse,

View file

@ -48,6 +48,32 @@ class ConnectionRead(BaseModel):
model_config = ConfigDict(from_attributes=True)
class ModelSelection(BaseModel):
model_id: str = Field(..., max_length=255)
display_name: str | None = Field(None, max_length=255)
source: ModelSource | str = ModelSource.DISCOVERED
supports_chat: bool | None = None
max_input_tokens: int | None = None
supports_image_input: bool | None = None
supports_tools: bool | None = None
supports_image_generation: bool | None = None
enabled: bool = False
metadata: dict[str, Any] = Field(default_factory=dict)
class ModelPreviewRead(BaseModel):
model_id: str
display_name: str | None = None
source: ModelSource | str = ModelSource.DISCOVERED
supports_chat: bool | None = None
max_input_tokens: int | None = None
supports_image_input: bool | None = None
supports_tools: bool | None = None
supports_image_generation: bool | None = None
enabled: bool = False
metadata: dict[str, Any] = Field(default_factory=dict)
class ConnectionCreate(BaseModel):
provider: str = Field(..., max_length=100)
base_url: str | None = Field(None, max_length=500)
@ -56,6 +82,7 @@ class ConnectionCreate(BaseModel):
scope: ConnectionScope = ConnectionScope.SEARCH_SPACE
search_space_id: int | None = None
enabled: bool = True
models: list[ModelSelection] = Field(default_factory=list)
class ConnectionUpdate(BaseModel):

View file

@ -8,6 +8,7 @@ from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any
import anyio
import httpx
import litellm
@ -292,6 +293,48 @@ def _litellm_static_models(conn: Connection) -> list[dict[str, Any]]:
return results
async def _discover_bedrock_models(conn: Connection) -> list[dict[str, Any]]:
params = (conn.extra or {}).get("litellm_params", {})
region_name = params.get("aws_region_name")
if not region_name:
return []
def list_models() -> list[dict[str, Any]]:
import boto3
client_kwargs: dict[str, str] = {"region_name": region_name}
if params.get("aws_access_key_id"):
client_kwargs["aws_access_key_id"] = params["aws_access_key_id"]
if params.get("aws_secret_access_key"):
client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"]
client = boto3.client("bedrock", **client_kwargs)
response = client.list_foundation_models()
results: list[dict[str, Any]] = []
for item in response.get("modelSummaries", []):
model_id = item.get("modelId")
if not model_id:
continue
input_modalities = set(item.get("inputModalities") or [])
output_modalities = set(item.get("outputModalities") or [])
results.append(
{
"model_id": model_id,
"display_name": item.get("modelName") or model_id,
"source": ModelSource.DISCOVERED,
"supports_chat": "TEXT" in input_modalities and "TEXT" in output_modalities,
"supports_image_input": "IMAGE" in input_modalities,
"supports_tools": None,
"supports_image_generation": "IMAGE" in output_modalities,
"max_input_tokens": None,
"metadata": item,
}
)
return results
return await anyio.to_thread.run_sync(list_models)
async def discover_models(conn: Connection) -> list[dict[str, Any]]:
allowlist = _allowlist(conn)
spec = spec_for(conn.provider)
@ -304,6 +347,8 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]:
results = await _discover_anthropic_models(conn)
elif spec.discovery == "openai_models":
results = await _discover_openai_shaped_models(conn, conn.base_url)
elif spec.discovery == "bedrock_models":
results = await _discover_bedrock_models(conn)
elif spec.discovery == "static":
results = _litellm_static_models(conn)
else:

View file

@ -21,6 +21,7 @@ DiscoveryKind = Literal[
"ollama",
"openai_models",
"anthropic_models",
"bedrock_models",
"openrouter",
"static",
"none",
@ -51,7 +52,7 @@ REGISTRY: dict[str, ProviderSpec] = {
Transport.NATIVE, "vertex_ai", "static", None, False, "native"
),
"bedrock": ProviderSpec(
Transport.NATIVE, "bedrock", "static", None, False, "native"
Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native"
),
"openrouter": ProviderSpec(
Transport.OPENAI_COMPATIBLE,