mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-20 21:18:13 +02:00
feat(model-connections): enhance model connection functionality with preview and selection features
This commit is contained in:
parent
356f0e56c5
commit
407f2a9612
20 changed files with 630 additions and 429 deletions
|
|
@ -21,10 +21,12 @@ from app.schemas import (
|
|||
ConnectionRead,
|
||||
ConnectionUpdate,
|
||||
ModelCreate,
|
||||
ModelPreviewRead,
|
||||
ModelProviderRead,
|
||||
ModelRead,
|
||||
ModelRolesRead,
|
||||
ModelRolesUpdate,
|
||||
ModelSelection,
|
||||
ModelsBulkUpdate,
|
||||
ModelUpdate,
|
||||
VerifyConnectionResponse,
|
||||
|
|
@ -48,6 +50,21 @@ def _model_read(model: Model | dict) -> ModelRead:
|
|||
return ModelRead.model_validate(model)
|
||||
|
||||
|
||||
def _preview_model_read(item: dict) -> ModelPreviewRead:
    """Build a ``ModelPreviewRead`` from one discovery result dict.

    Args:
        item: Mapping describing a model. ``model_id`` is required (a missing
            key raises ``KeyError``); every other key is optional.

    Returns:
        The preview schema. ``source`` defaults to ``ModelSource.DISCOVERED``
        and ``enabled`` to ``False`` when absent; ``metadata`` is taken from
        the first truthy value of ``item["metadata"]`` / ``item["catalog"]``,
        else an empty dict.
    """
    model_id = item["model_id"]

    # Keys that are copied verbatim and are simply None when missing.
    passthrough_keys = (
        "display_name",
        "supports_chat",
        "max_input_tokens",
        "supports_image_input",
        "supports_tools",
        "supports_image_generation",
    )
    passthrough = {key: item.get(key) for key in passthrough_keys}

    details = item.get("metadata") or item.get("catalog") or {}

    return ModelPreviewRead(
        model_id=model_id,
        source=item.get("source", ModelSource.DISCOVERED),
        enabled=item.get("enabled", False),
        metadata=details,
        **passthrough,
    )
|
||||
|
||||
|
||||
def _connection_read(conn: Connection | dict, models: list[Model | dict] | None = None) -> ConnectionRead:
|
||||
if isinstance(conn, dict):
|
||||
payload = {
|
||||
|
|
@ -86,6 +103,25 @@ def _apply_model_facts(model: Model, facts: dict) -> None:
|
|||
model.supports_image_generation = facts.get("supports_image_generation")
|
||||
|
||||
|
||||
def _selection_to_model(conn: Connection, selection: ModelSelection) -> Model:
    """Create an unsaved ``Model`` row for ``conn`` from a client selection.

    Args:
        conn: Connection that will own the model; only ``conn.id`` is read.
        selection: The submitted model selection.

    Returns:
        A new ``Model`` (not added to any session) with the selection's
        identity fields set and its fact fields applied through
        ``_apply_model_facts``.
    """
    # The schema allows ``source`` to be either a ModelSource or a plain
    # string, so coerce the string form before handing it to the ORM model.
    resolved_source = selection.source
    if not isinstance(resolved_source, ModelSource):
        resolved_source = ModelSource(resolved_source)

    new_model = Model(
        connection_id=conn.id,
        model_id=selection.model_id.strip(),
        display_name=selection.display_name,
        source=resolved_source,
        capabilities_override={},
        enabled=selection.enabled,
        catalog=selection.metadata,
    )

    facts = selection.model_dump()
    _apply_model_facts(new_model, facts)
    return new_model
|
||||
|
||||
|
||||
def _default_model_for(models: list[Model], capability: str) -> int | None:
|
||||
for model in models:
|
||||
if model.enabled and has_capability(model, capability):
|
||||
|
|
@ -226,7 +262,7 @@ async def create_connection(
|
|||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to create model connections in this search space",
|
||||
)
|
||||
payload = data.model_dump(exclude={"search_space_id"})
|
||||
payload = data.model_dump(exclude={"search_space_id", "models"})
|
||||
|
||||
conn = Connection(
|
||||
**payload,
|
||||
|
|
@ -234,9 +270,51 @@ async def create_connection(
|
|||
user_id=user.id,
|
||||
)
|
||||
session.add(conn)
|
||||
await session.flush()
|
||||
|
||||
seen_model_ids: set[str] = set()
|
||||
for selection in data.models:
|
||||
model_id = selection.model_id.strip()
|
||||
if not model_id or model_id in seen_model_ids:
|
||||
continue
|
||||
seen_model_ids.add(model_id)
|
||||
session.add(_selection_to_model(conn, selection))
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(conn)
|
||||
return _connection_read(conn, [])
|
||||
conn = await _load_connection(session, conn.id)
|
||||
await _default_unset_roles(session, conn, list(conn.models))
|
||||
await session.commit()
|
||||
conn = await _load_connection(session, conn.id)
|
||||
return _connection_read(conn, list(conn.models))
|
||||
|
||||
|
||||
# Preview endpoint: runs model discovery against an unsaved, in-memory
# Connection built from the request body and returns what was found.
# Nothing is added to the session here.
@router.post("/model-connections/discover-preview", response_model=list[ModelPreviewRead])
async def preview_connection_models(
    data: ConnectionCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    space_scoped = data.scope == ConnectionScope.SEARCH_SPACE

    # Same permission as creating a connection, checked only when a concrete
    # search space is targeted.
    if space_scoped and data.search_space_id is not None:
        await check_permission(
            session,
            user,
            data.search_space_id,
            Permission.LLM_CONFIGS_CREATE.value,
            "You don't have permission to create model connections in this search space",
        )

    # A search space id is only kept for search-space-scoped drafts.
    draft_space_id = data.search_space_id if space_scoped else None
    draft = Connection(
        provider=data.provider,
        base_url=data.base_url,
        api_key=data.api_key,
        extra=data.extra or {},
        scope=data.scope,
        enabled=data.enabled,
        search_space_id=draft_space_id,
        user_id=user.id,
    )

    found = await discover_models(draft)
    previews = [_preview_model_read(entry) for entry in found]
    return previews
|
||||
|
||||
|
||||
@router.put("/model-connections/{connection_id}", response_model=ConnectionRead)
|
||||
|
|
|
|||
|
|
@ -49,10 +49,12 @@ from .model_connections import (
|
|||
ConnectionRead,
|
||||
ConnectionUpdate,
|
||||
ModelCreate,
|
||||
ModelPreviewRead,
|
||||
ModelProviderRead,
|
||||
ModelRead,
|
||||
ModelRolesRead,
|
||||
ModelRolesUpdate,
|
||||
ModelSelection,
|
||||
ModelsBulkUpdate,
|
||||
ModelUpdate,
|
||||
VerifyConnectionResponse,
|
||||
|
|
|
|||
|
|
@ -48,6 +48,32 @@ class ConnectionRead(BaseModel):
|
|||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ModelSelection(BaseModel):
    # Request-side description of one model to attach to a connection
    # (used as the item type of ``ConnectionCreate.models``).

    # Identity. ``model_id`` is required; both strings are capped at 255 chars.
    model_id: str = Field(max_length=255)
    display_name: str | None = Field(default=None, max_length=255)

    # May be a ModelSource member or a raw string.
    source: ModelSource | str = ModelSource.DISCOVERED

    # Capability facts; each is optional and defaults to None.
    supports_chat: bool | None = None
    max_input_tokens: int | None = None
    supports_image_input: bool | None = None
    supports_tools: bool | None = None
    supports_image_generation: bool | None = None

    # Models are disabled unless the client says otherwise.
    enabled: bool = False

    # Free-form extra data; a fresh dict per instance.
    metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ModelPreviewRead(BaseModel):
    # Response-side description of one model found during a discovery
    # preview. Same fields as ``ModelSelection`` but without length limits.

    # Identity. Only ``model_id`` is required.
    model_id: str
    display_name: str | None = None

    # May be a ModelSource member or a raw string.
    source: ModelSource | str = ModelSource.DISCOVERED

    # Capability facts; each is optional and defaults to None.
    supports_chat: bool | None = None
    max_input_tokens: int | None = None
    supports_image_input: bool | None = None
    supports_tools: bool | None = None
    supports_image_generation: bool | None = None

    # Previewed models are reported as disabled unless stated otherwise.
    enabled: bool = False

    # Free-form extra data; a fresh dict per instance.
    metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ConnectionCreate(BaseModel):
|
||||
provider: str = Field(..., max_length=100)
|
||||
base_url: str | None = Field(None, max_length=500)
|
||||
|
|
@ -56,6 +82,7 @@ class ConnectionCreate(BaseModel):
|
|||
scope: ConnectionScope = ConnectionScope.SEARCH_SPACE
|
||||
search_space_id: int | None = None
|
||||
enabled: bool = True
|
||||
models: list[ModelSelection] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ConnectionUpdate(BaseModel):
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from dataclasses import dataclass
|
|||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
import litellm
|
||||
|
||||
|
|
@ -292,6 +293,48 @@ def _litellm_static_models(conn: Connection) -> list[dict[str, Any]]:
|
|||
return results
|
||||
|
||||
|
||||
async def _discover_bedrock_models(conn: Connection) -> list[dict[str, Any]]:
    """List the Amazon Bedrock foundation models reachable with ``conn``.

    AWS settings are read from ``conn.extra["litellm_params"]``:
    ``aws_region_name`` (required) and, optionally, ``aws_access_key_id``,
    ``aws_secret_access_key`` and ``aws_session_token``. Credentials that are
    not supplied are left for boto3 to resolve on its own.

    Args:
        conn: Connection whose ``extra`` mapping carries the AWS parameters.

    Returns:
        One dict per model summary that has a ``modelId``, with capability
        flags derived from the reported input/output modalities. Returns an
        empty list when ``litellm_params`` is missing, null or not a mapping,
        or when no region is configured.

    Raises:
        Whatever ``boto3`` raises while creating the client or calling
        ``list_foundation_models`` (propagated unchanged).
    """
    # ``litellm_params`` may be absent, explicitly null, or malformed; a
    # ``.get(..., {})`` default only covers the "absent" case.
    params = (conn.extra or {}).get("litellm_params")
    if not isinstance(params, dict):
        return []

    region_name = params.get("aws_region_name")
    if not region_name:
        return []

    def list_models() -> list[dict[str, Any]]:
        # Imported lazily so boto3 is only needed when Bedrock is used.
        import boto3

        client_kwargs: dict[str, str] = {"region_name": region_name}
        # Forward only the credentials that were actually provided. The
        # session token is required for temporary (STS) credentials.
        for key in ("aws_access_key_id", "aws_secret_access_key", "aws_session_token"):
            if params.get(key):
                client_kwargs[key] = params[key]

        client = boto3.client("bedrock", **client_kwargs)
        response = client.list_foundation_models()

        results: list[dict[str, Any]] = []
        for item in response.get("modelSummaries", []):
            model_id = item.get("modelId")
            if not model_id:
                continue
            input_modalities = set(item.get("inputModalities") or [])
            output_modalities = set(item.get("outputModalities") or [])
            results.append(
                {
                    "model_id": model_id,
                    "display_name": item.get("modelName") or model_id,
                    "source": ModelSource.DISCOVERED,
                    # Chat needs text on both sides of the exchange.
                    "supports_chat": "TEXT" in input_modalities and "TEXT" in output_modalities,
                    "supports_image_input": "IMAGE" in input_modalities,
                    # Not derivable from the modality lists used here.
                    "supports_tools": None,
                    "supports_image_generation": "IMAGE" in output_modalities,
                    "max_input_tokens": None,
                    "metadata": item,
                }
            )
        return results

    # boto3 is synchronous; run it in a worker thread to keep the event loop free.
    return await anyio.to_thread.run_sync(list_models)
|
||||
|
||||
|
||||
async def discover_models(conn: Connection) -> list[dict[str, Any]]:
|
||||
allowlist = _allowlist(conn)
|
||||
spec = spec_for(conn.provider)
|
||||
|
|
@ -304,6 +347,8 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]:
|
|||
results = await _discover_anthropic_models(conn)
|
||||
elif spec.discovery == "openai_models":
|
||||
results = await _discover_openai_shaped_models(conn, conn.base_url)
|
||||
elif spec.discovery == "bedrock_models":
|
||||
results = await _discover_bedrock_models(conn)
|
||||
elif spec.discovery == "static":
|
||||
results = _litellm_static_models(conn)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ DiscoveryKind = Literal[
|
|||
"ollama",
|
||||
"openai_models",
|
||||
"anthropic_models",
|
||||
"bedrock_models",
|
||||
"openrouter",
|
||||
"static",
|
||||
"none",
|
||||
|
|
@ -51,7 +52,7 @@ REGISTRY: dict[str, ProviderSpec] = {
|
|||
Transport.NATIVE, "vertex_ai", "static", None, False, "native"
|
||||
),
|
||||
"bedrock": ProviderSpec(
|
||||
Transport.NATIVE, "bedrock", "static", None, False, "native"
|
||||
Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native"
|
||||
),
|
||||
"openrouter": ProviderSpec(
|
||||
Transport.OPENAI_COMPATIBLE,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue