mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-26 21:39:43 +02:00
feat(model-connections): enhance auto mode with auto pinning
This commit is contained in:
parent
7a1bb2acd6
commit
45d27ba879
4 changed files with 196 additions and 19 deletions
|
|
@ -61,9 +61,21 @@ def _create_index_if_missing(
|
||||||
op.create_index(index_name, table_name, columns, unique=False)
|
op.create_index(index_name, table_name, columns, unique=False)
|
||||||
|
|
||||||
|
|
||||||
def _add_searchspace_column_if_missing(column_name: str) -> None:
|
def _add_searchspace_column_if_missing(
|
||||||
|
column_name: str,
|
||||||
|
*,
|
||||||
|
server_default: object | None = None,
|
||||||
|
) -> None:
|
||||||
if not _column_exists("searchspaces", column_name):
|
if not _column_exists("searchspaces", column_name):
|
||||||
op.add_column("searchspaces", sa.Column(column_name, sa.Integer(), nullable=True))
|
op.add_column(
|
||||||
|
"searchspaces",
|
||||||
|
sa.Column(
|
||||||
|
column_name,
|
||||||
|
sa.Integer(),
|
||||||
|
nullable=True,
|
||||||
|
server_default=server_default,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _drop_column_if_exists(table_name: str, column_name: str) -> None:
|
def _drop_column_if_exists(table_name: str, column_name: str) -> None:
|
||||||
|
|
@ -233,9 +245,26 @@ def upgrade() -> None:
|
||||||
_create_index_if_missing("ix_models_model_id", "models", ["model_id"])
|
_create_index_if_missing("ix_models_model_id", "models", ["model_id"])
|
||||||
_create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"])
|
_create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"])
|
||||||
|
|
||||||
_add_searchspace_column_if_missing("chat_model_id")
|
_add_searchspace_column_if_missing("chat_model_id", server_default=sa.text("0"))
|
||||||
_add_searchspace_column_if_missing("image_gen_model_id")
|
_add_searchspace_column_if_missing("image_gen_model_id", server_default=sa.text("0"))
|
||||||
_add_searchspace_column_if_missing("vision_model_id")
|
_add_searchspace_column_if_missing("vision_model_id", server_default=sa.text("0"))
|
||||||
|
for column_name in ("chat_model_id", "image_gen_model_id", "vision_model_id"):
|
||||||
|
op.alter_column(
|
||||||
|
"searchspaces",
|
||||||
|
column_name,
|
||||||
|
existing_type=sa.Integer(),
|
||||||
|
existing_nullable=True,
|
||||||
|
server_default=sa.text("0"),
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
UPDATE searchspaces
|
||||||
|
SET
|
||||||
|
chat_model_id = COALESCE(chat_model_id, 0),
|
||||||
|
image_gen_model_id = COALESCE(image_gen_model_id, 0),
|
||||||
|
vision_model_id = COALESCE(vision_model_id, 0)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
op.execute("DROP TYPE IF EXISTS connectionprotocol")
|
op.execute("DROP TYPE IF EXISTS connectionprotocol")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1853,13 +1853,13 @@ class SearchSpace(BaseModel, TimestampMixin):
|
||||||
# - Negative IDs: Global virtual models from global_llm_config.yaml
|
# - Negative IDs: Global virtual models from global_llm_config.yaml
|
||||||
# - Positive IDs: User/search-space models from the models table
|
# - Positive IDs: User/search-space models from the models table
|
||||||
chat_model_id = Column(
|
chat_model_id = Column(
|
||||||
Integer, nullable=True, default=0
|
Integer, nullable=True, default=0, server_default="0"
|
||||||
) # For agent/chat operations, defaults to Auto mode
|
) # For agent/chat operations, defaults to Auto mode
|
||||||
image_gen_model_id = Column(
|
image_gen_model_id = Column(
|
||||||
Integer, nullable=True, default=0
|
Integer, nullable=True, default=0, server_default="0"
|
||||||
) # For image generation, defaults to Auto mode when eligible
|
) # For image generation, defaults to Auto mode when eligible
|
||||||
vision_model_id = Column(
|
vision_model_id = Column(
|
||||||
Integer, nullable=True, default=0
|
Integer, nullable=True, default=0, server_default="0"
|
||||||
) # For vision/screenshot analysis, defaults to Auto mode
|
) # For vision/screenshot analysis, defaults to Auto mode
|
||||||
|
|
||||||
ai_file_sort_enabled = Column(
|
ai_file_sort_enabled = Column(
|
||||||
|
|
|
||||||
|
|
@ -131,6 +131,95 @@ def _default_model_for(models: list[Model], capability: str) -> int | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _load_role_model(
|
||||||
|
session: AsyncSession,
|
||||||
|
search_space_id: int,
|
||||||
|
model_id: int,
|
||||||
|
) -> Model | dict | None:
|
||||||
|
if model_id < 0:
|
||||||
|
return next(
|
||||||
|
(model for model in config.GLOBAL_MODELS if model.get("id") == model_id),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await session.execute(
|
||||||
|
select(Model)
|
||||||
|
.options(selectinload(Model.connection))
|
||||||
|
.where(Model.id == model_id)
|
||||||
|
)
|
||||||
|
model = result.scalars().first()
|
||||||
|
if model is None or model.connection.search_space_id != search_space_id:
|
||||||
|
return None
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _role_model_enabled(model: Model | dict) -> bool:
|
||||||
|
if isinstance(model, dict):
|
||||||
|
return bool(model.get("enabled", True))
|
||||||
|
return bool(model.enabled and model.connection.enabled)
|
||||||
|
|
||||||
|
|
||||||
|
async def _validate_role_model_id(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
search_space_id: int,
|
||||||
|
model_id: int | None,
|
||||||
|
capability: str,
|
||||||
|
) -> int:
|
||||||
|
if model_id is None or model_id == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
model = await _load_role_model(session, search_space_id, model_id)
|
||||||
|
if model and _role_model_enabled(model) and has_capability(model, capability):
|
||||||
|
return model_id
|
||||||
|
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Selected model is not available for {capability}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_role_model_id(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
search_space_id: int,
|
||||||
|
model_id: int | None,
|
||||||
|
capability: str,
|
||||||
|
) -> int:
|
||||||
|
try:
|
||||||
|
return await _validate_role_model_id(
|
||||||
|
session,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
model_id=model_id,
|
||||||
|
capability=capability,
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
async def _clear_invalid_roles(session: AsyncSession, search_space_id: int) -> SearchSpace:
|
||||||
|
search_space = await _get_search_space(session, search_space_id)
|
||||||
|
search_space.chat_model_id = await _resolve_role_model_id(
|
||||||
|
session,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
model_id=search_space.chat_model_id,
|
||||||
|
capability="chat",
|
||||||
|
)
|
||||||
|
search_space.vision_model_id = await _resolve_role_model_id(
|
||||||
|
session,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
model_id=search_space.vision_model_id,
|
||||||
|
capability="vision",
|
||||||
|
)
|
||||||
|
search_space.image_gen_model_id = await _resolve_role_model_id(
|
||||||
|
session,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
model_id=search_space.image_gen_model_id,
|
||||||
|
capability="image_gen",
|
||||||
|
)
|
||||||
|
return search_space
|
||||||
|
|
||||||
|
|
||||||
async def _default_unset_roles(
|
async def _default_unset_roles(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
conn: Connection,
|
conn: Connection,
|
||||||
|
|
@ -372,9 +461,13 @@ async def update_connection(
|
||||||
):
|
):
|
||||||
conn = await _load_connection(session, connection_id)
|
conn = await _load_connection(session, connection_id)
|
||||||
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value)
|
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value)
|
||||||
|
search_space_id = conn.search_space_id
|
||||||
for key, value in data.model_dump(exclude_unset=True).items():
|
for key, value in data.model_dump(exclude_unset=True).items():
|
||||||
setattr(conn, key, value)
|
setattr(conn, key, value)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
if search_space_id is not None:
|
||||||
|
await _clear_invalid_roles(session, search_space_id)
|
||||||
|
await session.commit()
|
||||||
conn = await _load_connection(session, connection_id)
|
conn = await _load_connection(session, connection_id)
|
||||||
return _connection_read(conn, list(conn.models))
|
return _connection_read(conn, list(conn.models))
|
||||||
|
|
||||||
|
|
@ -387,8 +480,12 @@ async def delete_connection(
|
||||||
):
|
):
|
||||||
conn = await _load_connection(session, connection_id)
|
conn = await _load_connection(session, connection_id)
|
||||||
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_DELETE.value)
|
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_DELETE.value)
|
||||||
|
search_space_id = conn.search_space_id
|
||||||
await session.delete(conn)
|
await session.delete(conn)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
if search_space_id is not None:
|
||||||
|
await _clear_invalid_roles(session, search_space_id)
|
||||||
|
await session.commit()
|
||||||
return {"status": "deleted"}
|
return {"status": "deleted"}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -439,6 +536,8 @@ async def discover_connection_models(
|
||||||
await session.commit()
|
await session.commit()
|
||||||
conn = await _load_connection(session, connection_id)
|
conn = await _load_connection(session, connection_id)
|
||||||
await _default_unset_roles(session, conn, list(conn.models))
|
await _default_unset_roles(session, conn, list(conn.models))
|
||||||
|
if conn.search_space_id is not None:
|
||||||
|
await _clear_invalid_roles(session, conn.search_space_id)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
conn = await _load_connection(session, connection_id)
|
conn = await _load_connection(session, connection_id)
|
||||||
return [_model_read(model) for model in conn.models]
|
return [_model_read(model) for model in conn.models]
|
||||||
|
|
@ -476,7 +575,10 @@ async def add_manual_model(
|
||||||
await session.refresh(model)
|
await session.refresh(model)
|
||||||
conn = await _load_connection(session, connection_id)
|
conn = await _load_connection(session, connection_id)
|
||||||
await _default_unset_roles(session, conn, list(conn.models))
|
await _default_unset_roles(session, conn, list(conn.models))
|
||||||
|
if conn.search_space_id is not None:
|
||||||
|
await _clear_invalid_roles(session, conn.search_space_id)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
await session.refresh(model)
|
||||||
return _model_read(model)
|
return _model_read(model)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -489,6 +591,7 @@ async def bulk_update_models(
|
||||||
):
|
):
|
||||||
conn = await _load_connection(session, connection_id)
|
conn = await _load_connection(session, connection_id)
|
||||||
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value)
|
await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value)
|
||||||
|
search_space_id = conn.search_space_id
|
||||||
|
|
||||||
model_ids = set(data.model_ids)
|
model_ids = set(data.model_ids)
|
||||||
await session.execute(
|
await session.execute(
|
||||||
|
|
@ -498,6 +601,10 @@ async def bulk_update_models(
|
||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
session.expire_all()
|
session.expire_all()
|
||||||
|
if search_space_id is not None:
|
||||||
|
await _clear_invalid_roles(session, search_space_id)
|
||||||
|
await session.commit()
|
||||||
|
session.expire_all()
|
||||||
|
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(Model)
|
select(Model)
|
||||||
|
|
@ -521,11 +628,16 @@ async def update_model(
|
||||||
if not model:
|
if not model:
|
||||||
raise HTTPException(status_code=404, detail="Model not found")
|
raise HTTPException(status_code=404, detail="Model not found")
|
||||||
await _assert_connection_access(session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value)
|
await _assert_connection_access(session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value)
|
||||||
|
search_space_id = model.connection.search_space_id
|
||||||
update = data.model_dump(exclude_unset=True)
|
update = data.model_dump(exclude_unset=True)
|
||||||
for key, value in update.items():
|
for key, value in update.items():
|
||||||
setattr(model, key, value)
|
setattr(model, key, value)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(model)
|
await session.refresh(model)
|
||||||
|
if search_space_id is not None:
|
||||||
|
await _clear_invalid_roles(session, search_space_id)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(model)
|
||||||
return _model_read(model)
|
return _model_read(model)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -560,7 +672,9 @@ async def get_model_roles(
|
||||||
Permission.LLM_CONFIGS_CREATE.value,
|
Permission.LLM_CONFIGS_CREATE.value,
|
||||||
"You don't have permission to view model roles in this search space",
|
"You don't have permission to view model roles in this search space",
|
||||||
)
|
)
|
||||||
search_space = await _get_search_space(session, search_space_id)
|
search_space = await _clear_invalid_roles(session, search_space_id)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(search_space)
|
||||||
return ModelRolesRead(
|
return ModelRolesRead(
|
||||||
chat_model_id=search_space.chat_model_id,
|
chat_model_id=search_space.chat_model_id,
|
||||||
vision_model_id=search_space.vision_model_id,
|
vision_model_id=search_space.vision_model_id,
|
||||||
|
|
@ -583,8 +697,28 @@ async def update_model_roles(
|
||||||
"You don't have permission to update model roles in this search space",
|
"You don't have permission to update model roles in this search space",
|
||||||
)
|
)
|
||||||
search_space = await _get_search_space(session, search_space_id)
|
search_space = await _get_search_space(session, search_space_id)
|
||||||
for key, value in data.model_dump(exclude_unset=True).items():
|
updates = data.model_dump(exclude_unset=True)
|
||||||
setattr(search_space, key, value)
|
if "chat_model_id" in updates:
|
||||||
|
search_space.chat_model_id = await _validate_role_model_id(
|
||||||
|
session,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
model_id=updates["chat_model_id"],
|
||||||
|
capability="chat",
|
||||||
|
)
|
||||||
|
if "vision_model_id" in updates:
|
||||||
|
search_space.vision_model_id = await _validate_role_model_id(
|
||||||
|
session,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
model_id=updates["vision_model_id"],
|
||||||
|
capability="vision",
|
||||||
|
)
|
||||||
|
if "image_gen_model_id" in updates:
|
||||||
|
search_space.image_gen_model_id = await _validate_role_model_id(
|
||||||
|
session,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
model_id=updates["image_gen_model_id"],
|
||||||
|
capability="image_gen",
|
||||||
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(search_space)
|
await session.refresh(search_space)
|
||||||
return ModelRolesRead(
|
return ModelRolesRead(
|
||||||
|
|
|
||||||
|
|
@ -65,6 +65,11 @@ function flattenModels(connections: ConnectionRead[]) {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function roleSelectValue(modelId: number | null | undefined, models: Array<{ id: number }>) {
|
||||||
|
if (!modelId) return "0";
|
||||||
|
return models.some((model) => model.id === modelId) ? String(modelId) : "0";
|
||||||
|
}
|
||||||
|
|
||||||
function ConnectionCard({ connection }: { connection: ConnectionRead }) {
|
function ConnectionCard({ connection }: { connection: ConnectionRead }) {
|
||||||
const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom);
|
const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom);
|
||||||
|
|
||||||
|
|
@ -349,8 +354,8 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
||||||
<div className="flex flex-col gap-6">
|
<div className="flex flex-col gap-6">
|
||||||
<div className="flex flex-col gap-4">
|
<div className="flex flex-col gap-4">
|
||||||
<div>
|
<div>
|
||||||
<h3 className="text-sm font-semibold">Model Roles</h3>
|
<h3 className="text-base font-semibold">Model Roles</h3>
|
||||||
<p className="text-xs text-muted-foreground">
|
<p className="text-sm text-muted-foreground">
|
||||||
Pick which enabled model powers chat, vision, and image generation for this search
|
Pick which enabled model powers chat, vision, and image generation for this search
|
||||||
space.
|
space.
|
||||||
</p>
|
</p>
|
||||||
|
|
@ -358,8 +363,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
||||||
<div className="flex w-full max-w-2xl flex-col gap-4">
|
<div className="flex w-full max-w-2xl flex-col gap-4">
|
||||||
<div className="flex flex-col gap-2">
|
<div className="flex flex-col gap-2">
|
||||||
<Label>Chat model</Label>
|
<Label>Chat model</Label>
|
||||||
|
<p className="text-xs text-muted-foreground">
|
||||||
|
Primary model for chat responses and agent tasks. You can also change it from the
|
||||||
|
chat.
|
||||||
|
</p>
|
||||||
<Select
|
<Select
|
||||||
value={String(roles?.chat_model_id ?? 0)}
|
value={roleSelectValue(roles?.chat_model_id, chatModels)}
|
||||||
onValueChange={(value) => updateRoles.mutate({ chat_model_id: Number(value) })}
|
onValueChange={(value) => updateRoles.mutate({ chat_model_id: Number(value) })}
|
||||||
>
|
>
|
||||||
<SelectTrigger className="w-full">
|
<SelectTrigger className="w-full">
|
||||||
|
|
@ -373,8 +382,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
||||||
</div>
|
</div>
|
||||||
<div className="flex flex-col gap-2">
|
<div className="flex flex-col gap-2">
|
||||||
<Label>Vision model</Label>
|
<Label>Vision model</Label>
|
||||||
|
<p className="text-xs text-muted-foreground">
|
||||||
|
Used to understand images in uploads, documents, connectors, and automations. Falls
|
||||||
|
back to chat model when possible.
|
||||||
|
</p>
|
||||||
<Select
|
<Select
|
||||||
value={String(roles?.vision_model_id ?? 0)}
|
value={roleSelectValue(roles?.vision_model_id, visionModels)}
|
||||||
onValueChange={(value) => updateRoles.mutate({ vision_model_id: Number(value) })}
|
onValueChange={(value) => updateRoles.mutate({ vision_model_id: Number(value) })}
|
||||||
>
|
>
|
||||||
<SelectTrigger className="w-full">
|
<SelectTrigger className="w-full">
|
||||||
|
|
@ -388,8 +401,9 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
||||||
</div>
|
</div>
|
||||||
<div className="flex flex-col gap-2">
|
<div className="flex flex-col gap-2">
|
||||||
<Label>Image generation model</Label>
|
<Label>Image generation model</Label>
|
||||||
|
<p className="text-xs text-muted-foreground">Used when generating images in chat.</p>
|
||||||
<Select
|
<Select
|
||||||
value={String(roles?.image_gen_model_id ?? 0)}
|
value={roleSelectValue(roles?.image_gen_model_id, imageModels)}
|
||||||
onValueChange={(value) => updateRoles.mutate({ image_gen_model_id: Number(value) })}
|
onValueChange={(value) => updateRoles.mutate({ image_gen_model_id: Number(value) })}
|
||||||
>
|
>
|
||||||
<SelectTrigger className="w-full">
|
<SelectTrigger className="w-full">
|
||||||
|
|
@ -409,8 +423,8 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
||||||
<div className="flex flex-col gap-6">
|
<div className="flex flex-col gap-6">
|
||||||
<div className="flex flex-col gap-3">
|
<div className="flex flex-col gap-3">
|
||||||
<div>
|
<div>
|
||||||
<h3 className="text-sm font-semibold">Add Provider</h3>
|
<h3 className="text-base font-semibold">Add Provider</h3>
|
||||||
<p className="text-xs text-muted-foreground">
|
<p className="text-sm text-muted-foreground">
|
||||||
SurfSense supports popular providers and self-hosted model endpoints.
|
SurfSense supports popular providers and self-hosted model endpoints.
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
@ -462,7 +476,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
|
||||||
{connections.length > 0 ? (
|
{connections.length > 0 ? (
|
||||||
<div className="flex flex-col gap-3">
|
<div className="flex flex-col gap-3">
|
||||||
<Separator />
|
<Separator />
|
||||||
<h3 className="text-sm font-semibold">Available Providers</h3>
|
<h3 className="text-base font-semibold">Available Providers</h3>
|
||||||
<div className="flex flex-col gap-3">
|
<div className="flex flex-col gap-3">
|
||||||
{connections.map((connection) => (
|
{connections.map((connection) => (
|
||||||
<ConnectionCard key={connection.id} connection={connection} />
|
<ConnectionCard key={connection.id} connection={connection} />
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue