From 45d27ba87992b6ec9c07a0df3d24357e7f830d2e Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 01:37:12 +0530 Subject: [PATCH] feat(model-connections): enhance auto mode with auto pinning --- .../versions/160_add_model_connections.py | 39 ++++- surfsense_backend/app/db.py | 6 +- .../app/routes/model_connections_routes.py | 140 +++++++++++++++++- .../settings/model-connections-settings.tsx | 30 +++- 4 files changed, 196 insertions(+), 19 deletions(-) diff --git a/surfsense_backend/alembic/versions/160_add_model_connections.py b/surfsense_backend/alembic/versions/160_add_model_connections.py index 49d6315ca..2c35bd568 100644 --- a/surfsense_backend/alembic/versions/160_add_model_connections.py +++ b/surfsense_backend/alembic/versions/160_add_model_connections.py @@ -61,9 +61,21 @@ def _create_index_if_missing( op.create_index(index_name, table_name, columns, unique=False) -def _add_searchspace_column_if_missing(column_name: str) -> None: +def _add_searchspace_column_if_missing( + column_name: str, + *, + server_default: object | None = None, +) -> None: if not _column_exists("searchspaces", column_name): - op.add_column("searchspaces", sa.Column(column_name, sa.Integer(), nullable=True)) + op.add_column( + "searchspaces", + sa.Column( + column_name, + sa.Integer(), + nullable=True, + server_default=server_default, + ), + ) def _drop_column_if_exists(table_name: str, column_name: str) -> None: @@ -233,9 +245,26 @@ def upgrade() -> None: _create_index_if_missing("ix_models_model_id", "models", ["model_id"]) _create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"]) - _add_searchspace_column_if_missing("chat_model_id") - _add_searchspace_column_if_missing("image_gen_model_id") - _add_searchspace_column_if_missing("vision_model_id") + _add_searchspace_column_if_missing("chat_model_id", server_default=sa.text("0")) + _add_searchspace_column_if_missing("image_gen_model_id", server_default=sa.text("0")) + _add_searchspace_column_if_missing("vision_model_id", server_default=sa.text("0")) + for column_name in ("chat_model_id", "image_gen_model_id", "vision_model_id"): + op.alter_column( + "searchspaces", + column_name, + existing_type=sa.Integer(), + existing_nullable=True, + server_default=sa.text("0"), + ) + op.execute( + """ + UPDATE searchspaces + SET + chat_model_id = COALESCE(chat_model_id, 0), + image_gen_model_id = COALESCE(image_gen_model_id, 0), + vision_model_id = COALESCE(vision_model_id, 0) + """ + ) op.execute("DROP TYPE IF EXISTS connectionprotocol") diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index eeed9932d..728031fa0 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -1853,13 +1853,13 @@ class SearchSpace(BaseModel, TimestampMixin): # - Negative IDs: Global virtual models from global_llm_config.yaml # - Positive IDs: User/search-space models from the models table chat_model_id = Column( - Integer, nullable=True, default=0 + Integer, nullable=True, default=0, server_default="0" ) # For agent/chat operations, defaults to Auto mode image_gen_model_id = Column( - Integer, nullable=True, default=0 + Integer, nullable=True, default=0, server_default="0" ) # For image generation, defaults to Auto mode when eligible vision_model_id = Column( - Integer, nullable=True, default=0 + Integer, nullable=True, default=0, server_default="0" ) # For vision/screenshot analysis, defaults to Auto mode ai_file_sort_enabled = Column( diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 76e4a3dfb..6db9aa9f3 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -131,6 +131,95 @@ def _default_model_for(models: list[Model], capability: str) -> int | None: return None +async def _load_role_model( + session: AsyncSession, + search_space_id: int, + model_id: int, +) -> Model | dict | None: + if model_id < 0: + return next( + (model for model in config.GLOBAL_MODELS if model.get("id") == model_id), + None, + ) + + result = await session.execute( + select(Model) + .options(selectinload(Model.connection)) + .where(Model.id == model_id) + ) + model = result.scalars().first() + if model is None or model.connection.search_space_id != search_space_id: + return None + return model + + +def _role_model_enabled(model: Model | dict) -> bool: + if isinstance(model, dict): + return bool(model.get("enabled", True)) + return bool(model.enabled and model.connection.enabled) + + +async def _validate_role_model_id( + session: AsyncSession, + *, + search_space_id: int, + model_id: int | None, + capability: str, +) -> int: + if model_id is None or model_id == 0: + return 0 + + model = await _load_role_model(session, search_space_id, model_id) + if model and _role_model_enabled(model) and has_capability(model, capability): + return model_id + + raise HTTPException( + status_code=400, + detail=f"Selected model is not available for {capability}", + ) + + +async def _resolve_role_model_id( + session: AsyncSession, + *, + search_space_id: int, + model_id: int | None, + capability: str, +) -> int: + try: + return await _validate_role_model_id( + session, + search_space_id=search_space_id, + model_id=model_id, + capability=capability, + ) + except HTTPException: + return 0 + + +async def _clear_invalid_roles(session: AsyncSession, search_space_id: int) -> SearchSpace: + search_space = await _get_search_space(session, search_space_id) + search_space.chat_model_id = await _resolve_role_model_id( + session, + search_space_id=search_space_id, + model_id=search_space.chat_model_id, + capability="chat", + ) + search_space.vision_model_id = await _resolve_role_model_id( + session, + search_space_id=search_space_id, + model_id=search_space.vision_model_id, + capability="vision", + ) + search_space.image_gen_model_id = await _resolve_role_model_id( + session, + search_space_id=search_space_id, + model_id=search_space.image_gen_model_id, + capability="image_gen", + ) + return search_space + + async def _default_unset_roles( session: AsyncSession, conn: Connection, @@ -372,9 +461,13 @@ async def update_connection( ): conn = await _load_connection(session, connection_id) await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value) + search_space_id = conn.search_space_id for key, value in data.model_dump(exclude_unset=True).items(): setattr(conn, key, value) await session.commit() + if search_space_id is not None: + await _clear_invalid_roles(session, search_space_id) + await session.commit() conn = await _load_connection(session, connection_id) return _connection_read(conn, list(conn.models)) @@ -387,8 +480,12 @@ async def delete_connection( ): conn = await _load_connection(session, connection_id) await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_DELETE.value) + search_space_id = conn.search_space_id await session.delete(conn) await session.commit() + if search_space_id is not None: + await _clear_invalid_roles(session, search_space_id) + await session.commit() return {"status": "deleted"} @@ -439,6 +536,8 @@ async def discover_connection_models( await session.commit() conn = await _load_connection(session, connection_id) await _default_unset_roles(session, conn, list(conn.models)) + if conn.search_space_id is not None: + await _clear_invalid_roles(session, conn.search_space_id) await session.commit() conn = await _load_connection(session, connection_id) return [_model_read(model) for model in conn.models] @@ -476,7 +575,10 @@ async def add_manual_model( await session.refresh(model) conn = await _load_connection(session, connection_id) await _default_unset_roles(session, conn, list(conn.models)) + if conn.search_space_id is not None: + await _clear_invalid_roles(session, conn.search_space_id) await session.commit() + await session.refresh(model) return _model_read(model) @@ -489,6 +591,7 @@ async def bulk_update_models( ): conn = await _load_connection(session, connection_id) await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value) + search_space_id = conn.search_space_id model_ids = set(data.model_ids) await session.execute( @@ -498,6 +601,10 @@ async def bulk_update_models( ) await session.commit() session.expire_all() + if search_space_id is not None: + await _clear_invalid_roles(session, search_space_id) + await session.commit() + session.expire_all() result = await session.execute( select(Model) @@ -521,11 +628,16 @@ async def update_model( if not model: raise HTTPException(status_code=404, detail="Model not found") await _assert_connection_access(session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value) + search_space_id = model.connection.search_space_id update = data.model_dump(exclude_unset=True) for key, value in update.items(): setattr(model, key, value) await session.commit() await session.refresh(model) + if search_space_id is not None: + await _clear_invalid_roles(session, search_space_id) + await session.commit() + await session.refresh(model) return _model_read(model) @@ -560,7 +672,9 @@ async def get_model_roles( Permission.LLM_CONFIGS_CREATE.value, "You don't have permission to view model roles in this search space", ) - search_space = await _get_search_space(session, search_space_id) + search_space = await _clear_invalid_roles(session, search_space_id) + await session.commit() + await session.refresh(search_space) return ModelRolesRead( chat_model_id=search_space.chat_model_id, vision_model_id=search_space.vision_model_id, @@ -583,8 +697,28 @@ async def update_model_roles( "You don't have permission to update model roles in this search space", ) search_space = await _get_search_space(session, search_space_id) - for key, value in data.model_dump(exclude_unset=True).items(): - setattr(search_space, key, value) + updates = data.model_dump(exclude_unset=True) + if "chat_model_id" in updates: + search_space.chat_model_id = await _validate_role_model_id( + session, + search_space_id=search_space_id, + model_id=updates["chat_model_id"], + capability="chat", + ) + if "vision_model_id" in updates: + search_space.vision_model_id = await _validate_role_model_id( + session, + search_space_id=search_space_id, + model_id=updates["vision_model_id"], + capability="vision", + ) + if "image_gen_model_id" in updates: + search_space.image_gen_model_id = await _validate_role_model_id( + session, + search_space_id=search_space_id, + model_id=updates["image_gen_model_id"], + capability="image_gen", + ) await session.commit() await session.refresh(search_space) return ModelRolesRead( diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index c61368974..1f5c166b3 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -65,6 +65,11 @@ function flattenModels(connections: ConnectionRead[]) { ); } +function roleSelectValue(modelId: number | null | undefined, models: Array<{ id: number }>) { + if (!modelId) return "0"; + return models.some((model) => model.id === modelId) ? String(modelId) : "0"; +} + function ConnectionCard({ connection }: { connection: ConnectionRead }) { const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom); @@ -349,8 +354,8 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
-

Model Roles

-

+

Model Roles

+

Pick which enabled model powers chat, vision, and image generation for this search space.

@@ -358,8 +363,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
+

+ Primary model for chat responses and agent tasks. You can also change it from the + chat. +

updateRoles.mutate({ vision_model_id: Number(value) })} > @@ -388,8 +401,9 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
+

Used when generating images in chat.