mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-14 20:55:15 +02:00
feat(model-connections): enhance auto mode with auto pinning
This commit is contained in:
parent
7a1bb2acd6
commit
45d27ba879
4 changed files with 196 additions and 19 deletions
|
|
@ -61,9 +61,21 @@ def _create_index_if_missing(
|
|||
op.create_index(index_name, table_name, columns, unique=False)
|
||||
|
||||
|
||||
def _add_searchspace_column_if_missing(
    column_name: str,
    *,
    server_default: object | None = None,
) -> None:
    """Add a nullable integer column to ``searchspaces`` unless it already exists.

    Args:
        column_name: Name of the column to create.
        server_default: Optional value forwarded unchanged to ``sa.Column`` as
            its ``server_default``; ``None`` leaves the column without one.
    """
    # Nothing to do when the column is already present, so re-running is safe.
    if _column_exists("searchspaces", column_name):
        return

    new_column = sa.Column(
        column_name,
        sa.Integer(),
        nullable=True,
        server_default=server_default,
    )
    op.add_column("searchspaces", new_column)
|
||||
|
||||
|
||||
def _drop_column_if_exists(table_name: str, column_name: str) -> None:
|
||||
|
|
@ -233,9 +245,26 @@ def upgrade() -> None:
|
|||
_create_index_if_missing("ix_models_model_id", "models", ["model_id"])
|
||||
_create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"])
|
||||
|
||||
_add_searchspace_column_if_missing("chat_model_id")
|
||||
_add_searchspace_column_if_missing("image_gen_model_id")
|
||||
_add_searchspace_column_if_missing("vision_model_id")
|
||||
_add_searchspace_column_if_missing("chat_model_id", server_default=sa.text("0"))
|
||||
_add_searchspace_column_if_missing("image_gen_model_id", server_default=sa.text("0"))
|
||||
_add_searchspace_column_if_missing("vision_model_id", server_default=sa.text("0"))
|
||||
for column_name in ("chat_model_id", "image_gen_model_id", "vision_model_id"):
|
||||
op.alter_column(
|
||||
"searchspaces",
|
||||
column_name,
|
||||
existing_type=sa.Integer(),
|
||||
existing_nullable=True,
|
||||
server_default=sa.text("0"),
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE searchspaces
|
||||
SET
|
||||
chat_model_id = COALESCE(chat_model_id, 0),
|
||||
image_gen_model_id = COALESCE(image_gen_model_id, 0),
|
||||
vision_model_id = COALESCE(vision_model_id, 0)
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute("DROP TYPE IF EXISTS connectionprotocol")
|
||||
|
||||
|
|
|
|||
|
|
@ -1853,13 +1853,13 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
# - Negative IDs: Global virtual models from global_llm_config.yaml
|
||||
# - Positive IDs: User/search-space models from the models table
|
||||
chat_model_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
Integer, nullable=True, default=0, server_default="0"
|
||||
) # For agent/chat operations, defaults to Auto mode
|
||||
image_gen_model_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
Integer, nullable=True, default=0, server_default="0"
|
||||
) # For image generation, defaults to Auto mode when eligible
|
||||
vision_model_id = Column(
|
||||
Integer, nullable=True, default=0
|
||||
Integer, nullable=True, default=0, server_default="0"
|
||||
) # For vision/screenshot analysis, defaults to Auto mode
|
||||
|
||||
ai_file_sort_enabled = Column(
|
||||
|
|
|
|||
|
|
@ -131,6 +131,95 @@ def _default_model_for(models: list[Model], capability: str) -> int | None:
|
|||
return None
|
||||
|
||||
|
||||
async def _load_role_model(
    session: AsyncSession,
    search_space_id: int,
    model_id: int,
) -> Model | dict | None:
    """Look up the model that a role ID refers to.

    Negative IDs are searched for in ``config.GLOBAL_MODELS`` (dict entries
    matched on their ``"id"`` key). Any other ID is fetched from the ``Model``
    table with its connection eagerly loaded, and is only returned when that
    connection belongs to ``search_space_id``.

    Returns:
        The matching global model dict or ``Model`` row, or ``None`` when
        nothing matches.
    """
    if model_id < 0:
        for candidate in config.GLOBAL_MODELS:
            if candidate.get("id") == model_id:
                return candidate
        return None

    query = (
        select(Model)
        .options(selectinload(Model.connection))
        .where(Model.id == model_id)
    )
    row = (await session.execute(query)).scalars().first()

    # A row owned by a different search space is treated the same as a miss.
    if row is not None and row.connection.search_space_id == search_space_id:
        return row
    return None
|
||||
|
||||
|
||||
def _role_model_enabled(model: Model | dict) -> bool:
    """Report whether a role model is currently enabled.

    For a ``Model`` row both the model and its connection must be enabled.
    For a dict entry the ``"enabled"`` key is used, defaulting to enabled
    when the key is absent.
    """
    if not isinstance(model, dict):
        return bool(model.enabled and model.connection.enabled)
    return bool(model.get("enabled", True))
|
||||
|
||||
|
||||
async def _validate_role_model_id(
    session: AsyncSession,
    *,
    search_space_id: int,
    model_id: int | None,
    capability: str,
) -> int:
    """Check that ``model_id`` may be assigned to a role with ``capability``.

    Returns:
        ``0`` when ``model_id`` is ``None`` or ``0``; otherwise ``model_id``
        itself once the model has been found, is enabled, and has the
        requested capability.

    Raises:
        HTTPException: With status 400 when the model cannot be loaded for
            this search space, is disabled, or lacks the capability.
    """
    # None and 0 both collapse to 0, so no lookup is needed.
    if not model_id:
        return 0

    candidate = await _load_role_model(session, search_space_id, model_id)
    is_usable = (
        candidate
        and _role_model_enabled(candidate)
        and has_capability(candidate, capability)
    )
    if not is_usable:
        raise HTTPException(
            status_code=400,
            detail=f"Selected model is not available for {capability}",
        )
    return model_id
|
||||
|
||||
|
||||
async def _resolve_role_model_id(
    session: AsyncSession,
    *,
    search_space_id: int,
    model_id: int | None,
    capability: str,
) -> int:
    """Non-raising variant of ``_validate_role_model_id``.

    Returns:
        The validated model ID, or ``0`` whenever validation raises an
        ``HTTPException``.
    """
    try:
        resolved_id = await _validate_role_model_id(
            session,
            search_space_id=search_space_id,
            model_id=model_id,
            capability=capability,
        )
    except HTTPException:
        # An unusable selection is reset to 0 rather than reported as an error.
        resolved_id = 0
    return resolved_id
|
||||
|
||||
|
||||
async def _clear_invalid_roles(session: AsyncSession, search_space_id: int) -> SearchSpace:
    """Re-resolve the chat, vision and image-generation roles of a search space.

    Each role column is passed through ``_resolve_role_model_id`` and
    overwritten with the result, so a selection that no longer validates
    becomes ``0``. The session is not committed here.

    Returns:
        The loaded ``SearchSpace`` with its role columns updated in memory.
    """
    search_space = await _get_search_space(session, search_space_id)

    # (column on SearchSpace, capability the assigned model must provide)
    role_capabilities = (
        ("chat_model_id", "chat"),
        ("vision_model_id", "vision"),
        ("image_gen_model_id", "image_gen"),
    )
    for attribute, capability in role_capabilities:
        resolved_id = await _resolve_role_model_id(
            session,
            search_space_id=search_space_id,
            model_id=getattr(search_space, attribute),
            capability=capability,
        )
        setattr(search_space, attribute, resolved_id)

    return search_space
|
||||
|
||||
|
||||
async def _default_unset_roles(
|
||||
session: AsyncSession,
|
||||
conn: Connection,
|
||||
|
|
@ -372,9 +461,13 @@ async def update_connection(
|
|||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value)
|
||||
search_space_id = conn.search_space_id
|
||||
for key, value in data.model_dump(exclude_unset=True).items():
|
||||
setattr(conn, key, value)
|
||||
await session.commit()
|
||||
if search_space_id is not None:
|
||||
await _clear_invalid_roles(session, search_space_id)
|
||||
await session.commit()
|
||||
conn = await _load_connection(session, connection_id)
|
||||
return _connection_read(conn, list(conn.models))
|
||||
|
||||
|
|
@ -387,8 +480,12 @@ async def delete_connection(
|
|||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_DELETE.value)
|
||||
search_space_id = conn.search_space_id
|
||||
await session.delete(conn)
|
||||
await session.commit()
|
||||
if search_space_id is not None:
|
||||
await _clear_invalid_roles(session, search_space_id)
|
||||
await session.commit()
|
||||
return {"status": "deleted"}
|
||||
|
||||
|
||||
|
|
@ -439,6 +536,8 @@ async def discover_connection_models(
|
|||
await session.commit()
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _default_unset_roles(session, conn, list(conn.models))
|
||||
if conn.search_space_id is not None:
|
||||
await _clear_invalid_roles(session, conn.search_space_id)
|
||||
await session.commit()
|
||||
conn = await _load_connection(session, connection_id)
|
||||
return [_model_read(model) for model in conn.models]
|
||||
|
|
@ -476,7 +575,10 @@ async def add_manual_model(
|
|||
await session.refresh(model)
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _default_unset_roles(session, conn, list(conn.models))
|
||||
if conn.search_space_id is not None:
|
||||
await _clear_invalid_roles(session, conn.search_space_id)
|
||||
await session.commit()
|
||||
await session.refresh(model)
|
||||
return _model_read(model)
|
||||
|
||||
|
||||
|
|
@ -489,6 +591,7 @@ async def bulk_update_models(
|
|||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value)
|
||||
search_space_id = conn.search_space_id
|
||||
|
||||
model_ids = set(data.model_ids)
|
||||
await session.execute(
|
||||
|
|
@ -498,6 +601,10 @@ async def bulk_update_models(
|
|||
)
|
||||
await session.commit()
|
||||
session.expire_all()
|
||||
if search_space_id is not None:
|
||||
await _clear_invalid_roles(session, search_space_id)
|
||||
await session.commit()
|
||||
session.expire_all()
|
||||
|
||||
result = await session.execute(
|
||||
select(Model)
|
||||
|
|
@ -521,11 +628,16 @@ async def update_model(
|
|||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
await _assert_connection_access(session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value)
|
||||
search_space_id = model.connection.search_space_id
|
||||
update = data.model_dump(exclude_unset=True)
|
||||
for key, value in update.items():
|
||||
setattr(model, key, value)
|
||||
await session.commit()
|
||||
await session.refresh(model)
|
||||
if search_space_id is not None:
|
||||
await _clear_invalid_roles(session, search_space_id)
|
||||
await session.commit()
|
||||
await session.refresh(model)
|
||||
return _model_read(model)
|
||||
|
||||
|
||||
|
|
@ -560,7 +672,9 @@ async def get_model_roles(
|
|||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to view model roles in this search space",
|
||||
)
|
||||
search_space = await _get_search_space(session, search_space_id)
|
||||
search_space = await _clear_invalid_roles(session, search_space_id)
|
||||
await session.commit()
|
||||
await session.refresh(search_space)
|
||||
return ModelRolesRead(
|
||||
chat_model_id=search_space.chat_model_id,
|
||||
vision_model_id=search_space.vision_model_id,
|
||||
|
|
@ -583,8 +697,28 @@ async def update_model_roles(
|
|||
"You don't have permission to update model roles in this search space",
|
||||
)
|
||||
search_space = await _get_search_space(session, search_space_id)
|
||||
for key, value in data.model_dump(exclude_unset=True).items():
|
||||
setattr(search_space, key, value)
|
||||
updates = data.model_dump(exclude_unset=True)
|
||||
if "chat_model_id" in updates:
|
||||
search_space.chat_model_id = await _validate_role_model_id(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
model_id=updates["chat_model_id"],
|
||||
capability="chat",
|
||||
)
|
||||
if "vision_model_id" in updates:
|
||||
search_space.vision_model_id = await _validate_role_model_id(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
model_id=updates["vision_model_id"],
|
||||
capability="vision",
|
||||
)
|
||||
if "image_gen_model_id" in updates:
|
||||
search_space.image_gen_model_id = await _validate_role_model_id(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
model_id=updates["image_gen_model_id"],
|
||||
capability="image_gen",
|
||||
)
|
||||
await session.commit()
|
||||
await session.refresh(search_space)
|
||||
return ModelRolesRead(
|
||||
|
|
|
|||
|
|
@ -65,6 +65,11 @@ function flattenModels(connections: ConnectionRead[]) {
|
|||
);
|
||||
}
|
||||
|
||||
/**
 * Map a stored role model id to the string value used by the role <Select>.
 *
 * Returns the id as a string only when it is truthy and present in `models`;
 * a missing, zero, or unlisted id yields "0".
 */
function roleSelectValue(modelId: number | null | undefined, models: Array<{ id: number }>) {
	if (modelId && models.some(({ id }) => id === modelId)) {
		return String(modelId);
	}
	return "0";
}
|
||||
|
||||
function ConnectionCard({ connection }: { connection: ConnectionRead }) {
|
||||
const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom);
|
||||
|
||||
|
|
@ -349,8 +354,8 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
|||
<div className="flex flex-col gap-6">
|
||||
<div className="flex flex-col gap-4">
|
||||
<div>
|
||||
<h3 className="text-sm font-semibold">Model Roles</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
<h3 className="text-base font-semibold">Model Roles</h3>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Pick which enabled model powers chat, vision, and image generation for this search
|
||||
space.
|
||||
</p>
|
||||
|
|
@ -358,8 +363,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
|||
<div className="flex w-full max-w-2xl flex-col gap-4">
|
||||
<div className="flex flex-col gap-2">
|
||||
<Label>Chat model</Label>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Primary model for chat responses and agent tasks. You can also change it from the
|
||||
chat.
|
||||
</p>
|
||||
<Select
|
||||
value={String(roles?.chat_model_id ?? 0)}
|
||||
value={roleSelectValue(roles?.chat_model_id, chatModels)}
|
||||
onValueChange={(value) => updateRoles.mutate({ chat_model_id: Number(value) })}
|
||||
>
|
||||
<SelectTrigger className="w-full">
|
||||
|
|
@ -373,8 +382,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
|||
</div>
|
||||
<div className="flex flex-col gap-2">
|
||||
<Label>Vision model</Label>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Used to understand images in uploads, documents, connectors, and automations. Falls
|
||||
back to chat model when possible.
|
||||
</p>
|
||||
<Select
|
||||
value={String(roles?.vision_model_id ?? 0)}
|
||||
value={roleSelectValue(roles?.vision_model_id, visionModels)}
|
||||
onValueChange={(value) => updateRoles.mutate({ vision_model_id: Number(value) })}
|
||||
>
|
||||
<SelectTrigger className="w-full">
|
||||
|
|
@ -388,8 +401,9 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
|||
</div>
|
||||
<div className="flex flex-col gap-2">
|
||||
<Label>Image generation model</Label>
|
||||
<p className="text-xs text-muted-foreground">Used when generating images in chat.</p>
|
||||
<Select
|
||||
value={String(roles?.image_gen_model_id ?? 0)}
|
||||
value={roleSelectValue(roles?.image_gen_model_id, imageModels)}
|
||||
onValueChange={(value) => updateRoles.mutate({ image_gen_model_id: Number(value) })}
|
||||
>
|
||||
<SelectTrigger className="w-full">
|
||||
|
|
@ -409,8 +423,8 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
|||
<div className="flex flex-col gap-6">
|
||||
<div className="flex flex-col gap-3">
|
||||
<div>
|
||||
<h3 className="text-sm font-semibold">Add Provider</h3>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
<h3 className="text-base font-semibold">Add Provider</h3>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
SurfSense supports popular providers and self-hosted model endpoints.
|
||||
</p>
|
||||
</div>
|
||||
|
|
@ -462,7 +476,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
|||
{connections.length > 0 ? (
|
||||
<div className="flex flex-col gap-3">
|
||||
<Separator />
|
||||
<h3 className="text-sm font-semibold">Available Providers</h3>
|
||||
<h3 className="text-base font-semibold">Available Providers</h3>
|
||||
<div className="flex flex-col gap-3">
|
||||
{connections.map((connection) => (
|
||||
<ConnectionCard key={connection.id} connection={connection} />
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue