mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-14 20:55:15 +02:00
feat(model-connections): implement bulk model update endpoint and related schema changes
This commit is contained in:
parent
ad404b2dbc
commit
ced1bb85ed
7 changed files with 538 additions and 168 deletions
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
|
|
@ -25,6 +25,7 @@ from app.schemas import (
|
|||
ModelRead,
|
||||
ModelRolesRead,
|
||||
ModelRolesUpdate,
|
||||
ModelsBulkUpdate,
|
||||
ModelUpdate,
|
||||
VerifyConnectionResponse,
|
||||
)
|
||||
|
|
@ -62,6 +63,7 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None
|
|||
id=conn.id,
|
||||
provider=conn.provider,
|
||||
base_url=conn.base_url,
|
||||
api_key=conn.api_key,
|
||||
extra=conn.extra or {},
|
||||
scope=conn.scope,
|
||||
search_space_id=conn.search_space_id,
|
||||
|
|
@ -351,6 +353,33 @@ async def add_manual_model(
|
|||
return _model_read(model)
|
||||
|
||||
|
||||
@router.patch("/model-connections/{connection_id}/models", response_model=list[ModelRead])
|
||||
async def bulk_update_models(
|
||||
connection_id: int,
|
||||
data: ModelsBulkUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value)
|
||||
|
||||
model_ids = set(data.model_ids)
|
||||
await session.execute(
|
||||
update(Model)
|
||||
.where(Model.connection_id == connection_id, Model.id.in_(model_ids))
|
||||
.values(enabled=data.enabled)
|
||||
)
|
||||
await session.commit()
|
||||
session.expire_all()
|
||||
|
||||
result = await session.execute(
|
||||
select(Model)
|
||||
.where(Model.connection_id == connection_id, Model.id.in_(model_ids))
|
||||
.order_by(Model.id)
|
||||
)
|
||||
return [_model_read(model) for model in result.scalars().all()]
|
||||
|
||||
|
||||
@router.put("/models/{model_id}", response_model=ModelRead)
|
||||
async def update_model(
|
||||
model_id: int,
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ from .model_connections import (
|
|||
ModelRead,
|
||||
ModelRolesRead,
|
||||
ModelRolesUpdate,
|
||||
ModelsBulkUpdate,
|
||||
ModelUpdate,
|
||||
VerifyConnectionResponse,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ class ConnectionRead(BaseModel):
|
|||
id: int
|
||||
provider: str
|
||||
base_url: str | None = None
|
||||
api_key: str | None = None
|
||||
extra: dict[str, Any] = Field(default_factory=dict)
|
||||
scope: ConnectionScope | str
|
||||
search_space_id: int | None = None
|
||||
|
|
@ -87,6 +88,11 @@ class ModelUpdate(BaseModel):
|
|||
capabilities_override: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ModelsBulkUpdate(BaseModel):
|
||||
model_ids: list[int] = Field(..., min_length=1, max_length=1000)
|
||||
enabled: bool
|
||||
|
||||
|
||||
class ModelProviderRead(BaseModel):
|
||||
provider: str
|
||||
transport: str
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue