feat(model-connections): enhance auto mode with auto pinning

This commit is contained in:
Anish Sarkar 2026-06-13 01:37:12 +05:30
parent 7a1bb2acd6
commit 45d27ba879
4 changed files with 196 additions and 19 deletions

View file

@ -61,9 +61,21 @@ def _create_index_if_missing(
op.create_index(index_name, table_name, columns, unique=False)
def _add_searchspace_column_if_missing(
    column_name: str,
    *,
    server_default: object | None = None,
) -> None:
    """Add a nullable integer column to ``searchspaces`` unless it already exists.

    Args:
        column_name: Name of the column to create.
        server_default: Optional database-side default handed straight to
            ``sa.Column``. ``None`` creates the column without a default.
    """
    # An existing column is left untouched; in particular its default is not
    # altered here.
    if _column_exists("searchspaces", column_name):
        return
    # Exactly one add_column call: issuing it a second time for the same
    # column name would fail once the first call has created the column.
    op.add_column(
        "searchspaces",
        sa.Column(
            column_name,
            sa.Integer(),
            nullable=True,
            server_default=server_default,
        ),
    )
def _drop_column_if_exists(table_name: str, column_name: str) -> None:
@ -233,9 +245,26 @@ def upgrade() -> None:
_create_index_if_missing("ix_models_model_id", "models", ["model_id"])
_create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"])
_add_searchspace_column_if_missing("chat_model_id")
_add_searchspace_column_if_missing("image_gen_model_id")
_add_searchspace_column_if_missing("vision_model_id")
_add_searchspace_column_if_missing("chat_model_id", server_default=sa.text("0"))
_add_searchspace_column_if_missing("image_gen_model_id", server_default=sa.text("0"))
_add_searchspace_column_if_missing("vision_model_id", server_default=sa.text("0"))
for column_name in ("chat_model_id", "image_gen_model_id", "vision_model_id"):
op.alter_column(
"searchspaces",
column_name,
existing_type=sa.Integer(),
existing_nullable=True,
server_default=sa.text("0"),
)
op.execute(
"""
UPDATE searchspaces
SET
chat_model_id = COALESCE(chat_model_id, 0),
image_gen_model_id = COALESCE(image_gen_model_id, 0),
vision_model_id = COALESCE(vision_model_id, 0)
"""
)
op.execute("DROP TYPE IF EXISTS connectionprotocol")

View file

@ -1853,13 +1853,13 @@ class SearchSpace(BaseModel, TimestampMixin):
# - Negative IDs: Global virtual models from global_llm_config.yaml
# - Positive IDs: User/search-space models from the models table
chat_model_id = Column(
Integer, nullable=True, default=0
Integer, nullable=True, default=0, server_default="0"
) # For agent/chat operations, defaults to Auto mode
image_gen_model_id = Column(
Integer, nullable=True, default=0
Integer, nullable=True, default=0, server_default="0"
) # For image generation, defaults to Auto mode when eligible
vision_model_id = Column(
Integer, nullable=True, default=0
Integer, nullable=True, default=0, server_default="0"
) # For vision/screenshot analysis, defaults to Auto mode
ai_file_sort_enabled = Column(

View file

@ -131,6 +131,95 @@ def _default_model_for(models: list[Model], capability: str) -> int | None:
return None
async def _load_role_model(
    session: AsyncSession,
    search_space_id: int,
    model_id: int,
) -> Model | dict | None:
    """Look up the model that a role ID points at.

    Negative IDs are searched for in ``config.GLOBAL_MODELS`` (dict entries
    matched on their ``"id"`` key). Any other ID is fetched from the ``Model``
    table with its connection eagerly loaded, and is returned only when that
    connection belongs to ``search_space_id``.

    Returns:
        The matching global-model dict or ``Model`` row, or ``None`` when
        nothing matches or the row belongs to a different search space.
    """
    if model_id < 0:
        for global_model in config.GLOBAL_MODELS:
            if global_model.get("id") == model_id:
                return global_model
        return None

    query = (
        select(Model)
        .options(selectinload(Model.connection))
        .where(Model.id == model_id)
    )
    row = (await session.execute(query)).scalars().first()
    # Rows owned by another search space are treated the same as missing rows.
    if row is not None and row.connection.search_space_id == search_space_id:
        return row
    return None
def _role_model_enabled(model: Model | dict) -> bool:
if isinstance(model, dict):
return bool(model.get("enabled", True))
return bool(model.enabled and model.connection.enabled)
async def _validate_role_model_id(
    session: AsyncSession,
    *,
    search_space_id: int,
    model_id: int | None,
    capability: str,
) -> int:
    """Check that ``model_id`` may be used for ``capability`` in a search space.

    ``None`` and ``0`` are accepted without any lookup and returned as ``0``.
    Every other ID must load via ``_load_role_model``, be enabled, and pass
    ``has_capability`` for the requested capability.

    Returns:
        ``0`` for an unset role, otherwise ``model_id`` unchanged.

    Raises:
        HTTPException: Status 400 when the model cannot be loaded, is
            disabled, or lacks the capability.
    """
    if model_id is None or model_id == 0:
        return 0

    candidate = await _load_role_model(session, search_space_id, model_id)
    # Short-circuits left to right, so the later checks never see a missing model.
    usable = (
        candidate
        and _role_model_enabled(candidate)
        and has_capability(candidate, capability)
    )
    if not usable:
        raise HTTPException(
            status_code=400,
            detail=f"Selected model is not available for {capability}",
        )
    return model_id
async def _resolve_role_model_id(
    session: AsyncSession,
    *,
    search_space_id: int,
    model_id: int | None,
    capability: str,
) -> int:
    """Non-raising counterpart of ``_validate_role_model_id``.

    Returns:
        Whatever ``_validate_role_model_id`` returns, or ``0`` when it raises
        ``HTTPException``.
    """
    try:
        resolved = await _validate_role_model_id(
            session,
            search_space_id=search_space_id,
            model_id=model_id,
            capability=capability,
        )
    except HTTPException:
        # A rejected selection falls back to 0 instead of surfacing the error.
        return 0
    return resolved
async def _clear_invalid_roles(session: AsyncSession, search_space_id: int) -> SearchSpace:
    """Re-check a search space's three role model IDs and reset invalid ones.

    Each of ``chat_model_id``, ``vision_model_id`` and ``image_gen_model_id``
    is passed through ``_resolve_role_model_id`` with its capability and the
    result is written back onto the search space object. This function does
    not commit the session.

    Returns:
        The search space object with the resolved role IDs assigned.
    """
    search_space = await _get_search_space(session, search_space_id)
    # (attribute on SearchSpace, capability name), processed in this order.
    role_capabilities = (
        ("chat_model_id", "chat"),
        ("vision_model_id", "vision"),
        ("image_gen_model_id", "image_gen"),
    )
    for attr_name, capability in role_capabilities:
        resolved_id = await _resolve_role_model_id(
            session,
            search_space_id=search_space_id,
            model_id=getattr(search_space, attr_name),
            capability=capability,
        )
        setattr(search_space, attr_name, resolved_id)
    return search_space
async def _default_unset_roles(
session: AsyncSession,
conn: Connection,
@ -372,9 +461,13 @@ async def update_connection(
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value)
search_space_id = conn.search_space_id
for key, value in data.model_dump(exclude_unset=True).items():
setattr(conn, key, value)
await session.commit()
if search_space_id is not None:
await _clear_invalid_roles(session, search_space_id)
await session.commit()
conn = await _load_connection(session, connection_id)
return _connection_read(conn, list(conn.models))
@ -387,8 +480,12 @@ async def delete_connection(
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_DELETE.value)
search_space_id = conn.search_space_id
await session.delete(conn)
await session.commit()
if search_space_id is not None:
await _clear_invalid_roles(session, search_space_id)
await session.commit()
return {"status": "deleted"}
@ -439,6 +536,8 @@ async def discover_connection_models(
await session.commit()
conn = await _load_connection(session, connection_id)
await _default_unset_roles(session, conn, list(conn.models))
if conn.search_space_id is not None:
await _clear_invalid_roles(session, conn.search_space_id)
await session.commit()
conn = await _load_connection(session, connection_id)
return [_model_read(model) for model in conn.models]
@ -476,7 +575,10 @@ async def add_manual_model(
await session.refresh(model)
conn = await _load_connection(session, connection_id)
await _default_unset_roles(session, conn, list(conn.models))
if conn.search_space_id is not None:
await _clear_invalid_roles(session, conn.search_space_id)
await session.commit()
await session.refresh(model)
return _model_read(model)
@ -489,6 +591,7 @@ async def bulk_update_models(
):
conn = await _load_connection(session, connection_id)
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value)
search_space_id = conn.search_space_id
model_ids = set(data.model_ids)
await session.execute(
@ -498,6 +601,10 @@ async def bulk_update_models(
)
await session.commit()
session.expire_all()
if search_space_id is not None:
await _clear_invalid_roles(session, search_space_id)
await session.commit()
session.expire_all()
result = await session.execute(
select(Model)
@ -521,11 +628,16 @@ async def update_model(
if not model:
raise HTTPException(status_code=404, detail="Model not found")
await _assert_connection_access(session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value)
search_space_id = model.connection.search_space_id
update = data.model_dump(exclude_unset=True)
for key, value in update.items():
setattr(model, key, value)
await session.commit()
await session.refresh(model)
if search_space_id is not None:
await _clear_invalid_roles(session, search_space_id)
await session.commit()
await session.refresh(model)
return _model_read(model)
@ -560,7 +672,9 @@ async def get_model_roles(
Permission.LLM_CONFIGS_CREATE.value,
"You don't have permission to view model roles in this search space",
)
search_space = await _get_search_space(session, search_space_id)
search_space = await _clear_invalid_roles(session, search_space_id)
await session.commit()
await session.refresh(search_space)
return ModelRolesRead(
chat_model_id=search_space.chat_model_id,
vision_model_id=search_space.vision_model_id,
@ -583,8 +697,28 @@ async def update_model_roles(
"You don't have permission to update model roles in this search space",
)
search_space = await _get_search_space(session, search_space_id)
for key, value in data.model_dump(exclude_unset=True).items():
setattr(search_space, key, value)
updates = data.model_dump(exclude_unset=True)
if "chat_model_id" in updates:
search_space.chat_model_id = await _validate_role_model_id(
session,
search_space_id=search_space_id,
model_id=updates["chat_model_id"],
capability="chat",
)
if "vision_model_id" in updates:
search_space.vision_model_id = await _validate_role_model_id(
session,
search_space_id=search_space_id,
model_id=updates["vision_model_id"],
capability="vision",
)
if "image_gen_model_id" in updates:
search_space.image_gen_model_id = await _validate_role_model_id(
session,
search_space_id=search_space_id,
model_id=updates["image_gen_model_id"],
capability="image_gen",
)
await session.commit()
await session.refresh(search_space)
return ModelRolesRead(

View file

@ -65,6 +65,11 @@ function flattenModels(connections: ConnectionRead[]) {
);
}
/**
 * Compute the string value for a role `<Select>`.
 *
 * Returns the model ID as a string when it is truthy and present in `models`;
 * otherwise returns "0" (unset, zero, or an ID missing from the list).
 */
function roleSelectValue(modelId: number | null | undefined, models: Array<{ id: number }>) {
	if (modelId) {
		for (const candidate of models) {
			if (candidate.id === modelId) {
				return String(modelId);
			}
		}
	}
	return "0";
}
function ConnectionCard({ connection }: { connection: ConnectionRead }) {
const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom);
@ -349,8 +354,8 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
<div className="flex flex-col gap-6">
<div className="flex flex-col gap-4">
<div>
<h3 className="text-sm font-semibold">Model Roles</h3>
<p className="text-xs text-muted-foreground">
<h3 className="text-base font-semibold">Model Roles</h3>
<p className="text-sm text-muted-foreground">
Pick which enabled model powers chat, vision, and image generation for this search
space.
</p>
@ -358,8 +363,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
<div className="flex w-full max-w-2xl flex-col gap-4">
<div className="flex flex-col gap-2">
<Label>Chat model</Label>
<p className="text-xs text-muted-foreground">
Primary model for chat responses and agent tasks. You can also change it from the
chat.
</p>
<Select
value={String(roles?.chat_model_id ?? 0)}
value={roleSelectValue(roles?.chat_model_id, chatModels)}
onValueChange={(value) => updateRoles.mutate({ chat_model_id: Number(value) })}
>
<SelectTrigger className="w-full">
@ -373,8 +382,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
</div>
<div className="flex flex-col gap-2">
<Label>Vision model</Label>
<p className="text-xs text-muted-foreground">
Used to understand images in uploads, documents, connectors, and automations. Falls
back to chat model when possible.
</p>
<Select
value={String(roles?.vision_model_id ?? 0)}
value={roleSelectValue(roles?.vision_model_id, visionModels)}
onValueChange={(value) => updateRoles.mutate({ vision_model_id: Number(value) })}
>
<SelectTrigger className="w-full">
@ -388,8 +401,9 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
</div>
<div className="flex flex-col gap-2">
<Label>Image generation model</Label>
<p className="text-xs text-muted-foreground">Used when generating images in chat.</p>
<Select
value={String(roles?.image_gen_model_id ?? 0)}
value={roleSelectValue(roles?.image_gen_model_id, imageModels)}
onValueChange={(value) => updateRoles.mutate({ image_gen_model_id: Number(value) })}
>
<SelectTrigger className="w-full">
@ -409,8 +423,8 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
<div className="flex flex-col gap-6">
<div className="flex flex-col gap-3">
<div>
<h3 className="text-sm font-semibold">Add Provider</h3>
<p className="text-xs text-muted-foreground">
<h3 className="text-base font-semibold">Add Provider</h3>
<p className="text-sm text-muted-foreground">
SurfSense supports popular providers and self-hosted model endpoints.
</p>
</div>
@ -462,7 +476,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
{connections.length > 0 ? (
<div className="flex flex-col gap-3">
<Separator />
<h3 className="text-sm font-semibold">Available Providers</h3>
<h3 className="text-base font-semibold">Available Providers</h3>
<div className="flex flex-col gap-3">
{connections.map((connection) => (
<ConnectionCard key={connection.id} connection={connection} />