feat(model-connections): implement bulk model update endpoint and related schema changes

This commit is contained in:
Anish Sarkar 2026-06-12 09:43:56 +05:30
parent ad404b2dbc
commit ced1bb85ed
7 changed files with 538 additions and 168 deletions

View file

@ -1,7 +1,7 @@
import logging
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
@ -25,6 +25,7 @@ from app.schemas import (
ModelRead,
ModelRolesRead,
ModelRolesUpdate,
ModelsBulkUpdate,
ModelUpdate,
VerifyConnectionResponse,
)
@ -62,6 +63,7 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None
id=conn.id,
provider=conn.provider,
base_url=conn.base_url,
api_key=conn.api_key,
extra=conn.extra or {},
scope=conn.scope,
search_space_id=conn.search_space_id,
@ -351,6 +353,33 @@ async def add_manual_model(
return _model_read(model)
@router.patch("/model-connections/{connection_id}/models", response_model=list[ModelRead])
async def bulk_update_models(
connection_id: int,
data: ModelsBulkUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value)
model_ids = set(data.model_ids)
await session.execute(
update(Model)
.where(Model.connection_id == connection_id, Model.id.in_(model_ids))
.values(enabled=data.enabled)
)
await session.commit()
session.expire_all()
result = await session.execute(
select(Model)
.where(Model.connection_id == connection_id, Model.id.in_(model_ids))
.order_by(Model.id)
)
return [_model_read(model) for model in result.scalars().all()]
@router.put("/models/{model_id}", response_model=ModelRead)
async def update_model(
model_id: int,

View file

@ -53,6 +53,7 @@ from .model_connections import (
ModelRead,
ModelRolesRead,
ModelRolesUpdate,
ModelsBulkUpdate,
ModelUpdate,
VerifyConnectionResponse,
)

View file

@ -32,6 +32,7 @@ class ConnectionRead(BaseModel):
id: int
provider: str
base_url: str | None = None
api_key: str | None = None
extra: dict[str, Any] = Field(default_factory=dict)
scope: ConnectionScope | str
search_space_id: int | None = None
@ -87,6 +88,11 @@ class ModelUpdate(BaseModel):
capabilities_override: dict[str, Any] | None = None
class ModelsBulkUpdate(BaseModel):
model_ids: list[int] = Field(..., min_length=1, max_length=1000)
enabled: bool
class ModelProviderRead(BaseModel):
provider: str
transport: str