From adb857925b97fa1508075c3c0d5a83bef31f828d Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:47:23 +0530 Subject: [PATCH 001/212] feat(models): add model connection persistence --- .../versions/156_add_model_connections.py | 178 +++++++++ surfsense_backend/app/db.py | 125 ++++++- surfsense_backend/app/routes/__init__.py | 2 + .../app/routes/model_connections_routes.py | 346 ++++++++++++++++++ surfsense_backend/app/schemas/__init__.py | 10 + .../app/schemas/model_connections.py | 89 +++++ .../app/services/model_connection_service.py | 209 +++++++++++ .../unit/services/test_model_connections.py | 75 ++++ 8 files changed, 1033 insertions(+), 1 deletion(-) create mode 100644 surfsense_backend/alembic/versions/156_add_model_connections.py create mode 100644 surfsense_backend/app/routes/model_connections_routes.py create mode 100644 surfsense_backend/app/schemas/model_connections.py create mode 100644 surfsense_backend/app/services/model_connection_service.py create mode 100644 surfsense_backend/tests/unit/services/test_model_connections.py diff --git a/surfsense_backend/alembic/versions/156_add_model_connections.py b/surfsense_backend/alembic/versions/156_add_model_connections.py new file mode 100644 index 000000000..0a11d7f9d --- /dev/null +++ b/surfsense_backend/alembic/versions/156_add_model_connections.py @@ -0,0 +1,178 @@ +"""add model connections + +Revision ID: 156 +Revises: 155 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +revision: str = "156" +down_revision: str | None = "155" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +connection_protocol = postgresql.ENUM( + "OLLAMA", + "OPENAI_COMPATIBLE", + "NATIVE", + name="connectionprotocol", + create_type=False, +) +connection_scope = postgresql.ENUM( + "GLOBAL", + "SEARCH_SPACE", + "USER", + name="connectionscope", + create_type=False, +) +model_source = postgresql.ENUM( + "DISCOVERED", + "MANUAL", + name="modelsource", + create_type=False, +) + + +def upgrade() -> None: + bind = op.get_bind() + connection_protocol.create(bind, checkfirst=True) + connection_scope.create(bind, checkfirst=True) + model_source.create(bind, checkfirst=True) + + op.create_table( + "connections", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("protocol", connection_protocol, nullable=False), + sa.Column("native_provider", sa.String(length=100), nullable=True), + sa.Column("base_url", sa.String(length=500), nullable=True), + sa.Column("api_key", sa.String(), nullable=True), + sa.Column( + "extra", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column("scope", connection_scope, nullable=False), + sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False), + sa.Column("search_space_id", sa.Integer(), nullable=True), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("last_verified_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column("last_status", sa.String(length=50), nullable=True), + sa.Column("last_error", sa.Text(), nullable=True), + sa.CheckConstraint( + "(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR " + "(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR " + "(scope = 'USER' AND user_id IS NOT NULL)", + name="ck_connections_scope_owner", + ), + sa.ForeignKeyConstraint( + ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_connections_protocol"), "connections", ["protocol"], unique=False) + op.create_index( + op.f("ix_connections_native_provider"), + "connections", + ["native_provider"], + unique=False, + ) + op.create_index(op.f("ix_connections_scope"), "connections", ["scope"], unique=False) + + op.create_table( + "models", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("connection_id", sa.Integer(), nullable=False), + sa.Column("model_id", sa.String(length=255), nullable=False), + sa.Column("display_name", sa.String(length=255), nullable=True), + sa.Column( + "source", + model_source, + server_default="DISCOVERED", + nullable=False, + ), + sa.Column( + "capabilities", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column( + "capabilities_declared", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column( + "capabilities_verified", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column( + "capabilities_override", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column("embedding_dimension", sa.Integer(), nullable=True), + sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False), + sa.Column("billing_tier", sa.String(length=50), nullable=True), + sa.Column( + "catalog", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.ForeignKeyConstraint(["connection_id"], ["connections.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint( + "connection_id", "model_id", name="uq_models_connection_model_id" + ), + ) + op.create_index(op.f("ix_models_connection_id"), "models", ["connection_id"], unique=False) + op.create_index("ix_models_model_id", "models", ["model_id"], unique=False) + op.create_index(op.f("ix_models_billing_tier"), "models", ["billing_tier"], unique=False) + + op.add_column( + "searchspaces", + sa.Column("chat_model_id", sa.Integer(), nullable=True), + ) + op.add_column( + "searchspaces", + sa.Column("image_gen_model_id", sa.Integer(), nullable=True), + ) + op.add_column( + "searchspaces", + sa.Column("vision_model_id", sa.Integer(), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("searchspaces", "vision_model_id") + op.drop_column("searchspaces", "image_gen_model_id") + op.drop_column("searchspaces", "chat_model_id") + + op.drop_index(op.f("ix_models_billing_tier"), table_name="models") + op.drop_index("ix_models_model_id", table_name="models") + op.drop_index(op.f("ix_models_connection_id"), table_name="models") + op.drop_table("models") + + op.drop_index(op.f("ix_connections_scope"), table_name="connections") + op.drop_index(op.f("ix_connections_native_provider"), table_name="connections") + op.drop_index(op.f("ix_connections_protocol"), table_name="connections") + op.drop_table("connections") + + bind = op.get_bind() + model_source.drop(bind, checkfirst=True) + connection_scope.drop(bind, checkfirst=True) + connection_protocol.drop(bind, checkfirst=True) diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 6117caecb..9756cb32f 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -280,6 +280,23 @@ class VisionProvider(StrEnum): CUSTOM = "CUSTOM" +class ConnectionProtocol(StrEnum): + OLLAMA = "OLLAMA" + OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" + NATIVE = "NATIVE" + + +class ConnectionScope(StrEnum): + GLOBAL = "GLOBAL" + SEARCH_SPACE = "SEARCH_SPACE" + USER = "USER" + + +class ModelSource(StrEnum): + DISCOVERED = "DISCOVERED" + MANUAL = "MANUAL" + + class LogLevel(StrEnum): DEBUG = "DEBUG" INFO = "INFO" @@ -1642,6 +1659,79 @@ class Report(BaseModel, TimestampMixin): thread = relationship("NewChatThread") +class Connection(BaseModel, TimestampMixin): + __tablename__ = "connections" + + protocol = Column(SQLAlchemyEnum(ConnectionProtocol), nullable=False, index=True) + native_provider = Column(String(100), nullable=True, index=True) + base_url = Column(String(500), nullable=True) + api_key = Column(String, nullable=True) + extra = Column(JSONB, nullable=False, default=dict, server_default="{}") + scope = Column(SQLAlchemyEnum(ConnectionScope), nullable=False, index=True) + enabled = Column(Boolean, nullable=False, default=True, server_default="true") + + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=True + ) + user_id = Column( + UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=True + ) + + last_verified_at = Column(TIMESTAMP(timezone=True), nullable=True) + last_status = Column(String(50), nullable=True) + last_error = Column(Text, nullable=True) + + search_space = relationship("SearchSpace", back_populates="connections") + user = relationship("User", back_populates="connections") + models = relationship( + "Model", + back_populates="connection", + order_by="Model.id", + cascade="all, delete-orphan", + passive_deletes=True, + ) + + __table_args__ = ( + CheckConstraint( + "(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR " + "(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR " + "(scope = 'USER' AND user_id IS NOT NULL)", + name="ck_connections_scope_owner", + ), + ) + + +class Model(BaseModel, TimestampMixin): + __tablename__ = "models" + + connection_id = Column( + Integer, ForeignKey("connections.id", ondelete="CASCADE"), nullable=False, index=True + ) + model_id = Column(String(255), nullable=False) + display_name = Column(String(255), nullable=True) + source = Column( + SQLAlchemyEnum(ModelSource), + nullable=False, + default=ModelSource.DISCOVERED, + server_default=ModelSource.DISCOVERED.value, + ) + capabilities = Column(JSONB, nullable=False, default=dict, server_default="{}") + capabilities_declared = Column(JSONB, nullable=False, default=dict, server_default="{}") + capabilities_verified = Column(JSONB, nullable=False, default=dict, server_default="{}") + capabilities_override = Column(JSONB, nullable=False, default=dict, server_default="{}") + embedding_dimension = Column(Integer, nullable=True) + enabled = Column(Boolean, nullable=False, default=True, server_default="true") + billing_tier = Column(String(50), nullable=True, index=True) + catalog = Column(JSONB, nullable=False, default=dict, server_default="{}") + + connection = relationship("Connection", back_populates="models") + + __table_args__ = ( + UniqueConstraint("connection_id", "model_id", name="uq_models_connection_model_id"), + Index("ix_models_model_id", "model_id"), + ) + + class ImageGenerationConfig(BaseModel, TimestampMixin): """ Dedicated configuration table for image generation models. @@ -1794,7 +1884,7 @@ class SearchSpace(BaseModel, TimestampMixin): # - Positive IDs: Custom configs from DB (NewLLMConfig table) agent_llm_id = Column( Integer, nullable=True, default=0 - ) # For agent/chat operations, defaults to Auto mode + ) # For chat operations, defaults to Auto mode image_generation_config_id = Column( Integer, nullable=True, default=0 ) # For image generation, defaults to Auto mode @@ -1802,6 +1892,22 @@ class SearchSpace(BaseModel, TimestampMixin): Integer, nullable=True, default=0 ) # For vision/screenshot analysis, defaults to Auto mode + # New connection/model role bindings. These supersede the legacy config + # columns above without removing them in this PR. + # Note: ID values preserve the existing convention: + # - 0: Auto mode + # - Negative IDs: Global virtual models from global_llm_config.yaml + # - Positive IDs: User/search-space models from the models table + chat_model_id = Column( + Integer, nullable=True, default=0 + ) # For agent/chat operations, defaults to Auto mode + image_gen_model_id = Column( + Integer, nullable=True, default=0 + ) # For image generation, defaults to Auto mode when eligible + vision_model_id = Column( + Integer, nullable=True, default=0 + ) # For vision/screenshot analysis, defaults to Auto mode + ai_file_sort_enabled = Column( Boolean, nullable=False, default=False, server_default="false" ) @@ -1889,6 +1995,13 @@ class SearchSpace(BaseModel, TimestampMixin): order_by="VisionLLMConfig.id", cascade="all, delete-orphan", ) + connections = relationship( + "Connection", + back_populates="search_space", + order_by="Connection.id", + cascade="all, delete-orphan", + passive_deletes=True, + ) automations = relationship( "Automation", @@ -2429,6 +2542,11 @@ if config.AUTH_TYPE == "GOOGLE": back_populates="user", passive_deletes=True, ) + connections = relationship( + "Connection", + back_populates="user", + passive_deletes=True, + ) # Automations created by this user automations = relationship( @@ -2568,6 +2686,11 @@ else: back_populates="user", passive_deletes=True, ) + connections = relationship( + "Connection", + back_populates="user", + passive_deletes=True, + ) # Automations created by this user automations = relationship( diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index 5cc029884..244208550 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -42,6 +42,7 @@ from .linear_add_connector_route import router as linear_add_connector_router from .logs_routes import router as logs_router from .luma_add_connector_route import router as luma_add_connector_router from .mcp_oauth_route import router as mcp_oauth_router +from .model_connections_routes import router as model_connections_router from .memory_routes import router as memory_router from .model_list_routes import router as model_list_router from .new_chat_routes import router as new_chat_router @@ -117,6 +118,7 @@ router.include_router(confluence_add_connector_router) router.include_router(clickup_add_connector_router) router.include_router(dropbox_add_connector_router) router.include_router(new_llm_config_router) # LLM configs with prompt configuration +router.include_router(model_connections_router) # Connection-centric model catalog router.include_router(model_list_router) # Dynamic model catalogue from OpenRouter router.include_router(logs_router) router.include_router(circleback_webhook_router) # Circleback meeting webhooks diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py new file mode 100644 index 000000000..69910183d --- /dev/null +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -0,0 +1,346 @@ +import logging + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from app.config import config +from app.db import ( + Connection, + ConnectionScope, + Model, + Permission, + SearchSpace, + User, + get_async_session, +) +from app.schemas import ( + ConnectionCreate, + ConnectionRead, + ConnectionUpdate, + ModelRead, + ModelRolesRead, + ModelRolesUpdate, + ModelUpdate, + VerifyConnectionResponse, +) +from app.services.model_connection_service import ( + discover_models, + persist_verification, + test_model, +) +from app.users import current_active_user +from app.utils.rbac import check_permission + +router = APIRouter() +logger = logging.getLogger(__name__) + + +def _model_read(model: Model | dict) -> ModelRead: + return ModelRead.model_validate(model) + + +def _connection_read(conn: Connection | dict, models: list[Model | dict] | None = None) -> ConnectionRead: + if isinstance(conn, dict): + payload = { + **conn, + "has_api_key": bool(conn.get("api_key")), + "api_key": None, + "models": [_model_read(model) for model in (models or [])], + } + payload.pop("api_key", None) + return ConnectionRead.model_validate(payload) + + return ConnectionRead( + id=conn.id, + protocol=conn.protocol, + native_provider=conn.native_provider, + base_url=conn.base_url, + extra=conn.extra or {}, + scope=conn.scope, + search_space_id=conn.search_space_id, + user_id=conn.user_id, + enabled=conn.enabled, + has_api_key=bool(conn.api_key), + last_verified_at=conn.last_verified_at, + last_status=conn.last_status, + last_error=conn.last_error, + models=[_model_read(model) for model in (models or conn.models or [])], + created_at=conn.created_at, + ) + + +async def _get_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace: + result = await session.execute(select(SearchSpace).where(SearchSpace.id == search_space_id)) + search_space = result.scalars().first() + if not search_space: + raise HTTPException(status_code=404, detail="Search space not found") + return search_space + + +async def _load_connection(session: AsyncSession, connection_id: int) -> Connection: + result = await session.execute( + select(Connection) + .options(selectinload(Connection.models)) + .where(Connection.id == connection_id) + ) + conn = result.scalars().first() + if not conn: + raise HTTPException(status_code=404, detail="Connection not found") + return conn + + +async def _assert_connection_access( + session: AsyncSession, + user: User, + conn: Connection, + permission: str = Permission.LLM_CONFIGS_CREATE.value, +) -> None: + if conn.search_space_id: + await check_permission( + session, + user, + conn.search_space_id, + permission, + "You don't have permission to manage model connections in this search space", + ) + return + if conn.user_id != user.id: + raise HTTPException(status_code=403, detail="Connection does not belong to user") + + +@router.get("/global-model-connections", response_model=list[ConnectionRead]) +async def list_global_connections(user: User = Depends(current_active_user)): + del user + models_by_connection: dict[int, list[dict]] = {} + for model in config.GLOBAL_MODELS: + models_by_connection.setdefault(model["connection_id"], []).append(model) + return [ + _connection_read(conn, models_by_connection.get(conn["id"], [])) + for conn in config.GLOBAL_CONNECTIONS + ] + + +@router.get("/model-connections", response_model=list[ConnectionRead]) +async def list_connections( + search_space_id: int | None = None, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + stmt = select(Connection).options(selectinload(Connection.models)) + if search_space_id is not None: + await check_permission( + session, + user, + search_space_id, + Permission.LLM_CONFIGS_CREATE.value, + "You don't have permission to view model connections in this search space", + ) + stmt = stmt.where(Connection.search_space_id == search_space_id) + else: + stmt = stmt.where(Connection.user_id == user.id) + result = await session.execute(stmt.order_by(Connection.id)) + return [_connection_read(conn) for conn in result.scalars().all()] + + +@router.post("/model-connections", response_model=ConnectionRead) +async def create_connection( + data: ConnectionCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + if data.scope == ConnectionScope.GLOBAL: + raise HTTPException(status_code=400, detail="GLOBAL connections are YAML-only") + if data.scope == ConnectionScope.SEARCH_SPACE: + if data.search_space_id is None: + raise HTTPException(status_code=400, detail="search_space_id is required") + await check_permission( + session, + user, + data.search_space_id, + Permission.LLM_CONFIGS_CREATE.value, + "You don't have permission to create model connections in this search space", + ) + conn = Connection( + **data.model_dump(exclude={"search_space_id"}), + search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None, + user_id=user.id, + ) + session.add(conn) + await session.commit() + await session.refresh(conn) + return _connection_read(conn) + + +@router.put("/model-connections/{connection_id}", response_model=ConnectionRead) +async def update_connection( + connection_id: int, + data: ConnectionUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + conn = await _load_connection(session, connection_id) + await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value) + for key, value in data.model_dump(exclude_unset=True).items(): + setattr(conn, key, value) + await session.commit() + await session.refresh(conn) + return _connection_read(conn) + + +@router.delete("/model-connections/{connection_id}") +async def delete_connection( + connection_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + conn = await _load_connection(session, connection_id) + await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_DELETE.value) + await session.delete(conn) + await session.commit() + return {"status": "deleted"} + + +@router.post("/model-connections/{connection_id}/verify", response_model=VerifyConnectionResponse) +async def verify_model_connection( + connection_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + conn = await _load_connection(session, connection_id) + await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_CREATE.value) + result = await persist_verification(conn) + await session.commit() + return VerifyConnectionResponse(status=result.status, ok=result.ok, message=result.message) + + +@router.post("/model-connections/{connection_id}/discover", response_model=list[ModelRead]) +async def discover_connection_models( + connection_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + conn = await _load_connection(session, connection_id) + await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_CREATE.value) + discovered = await discover_models(conn) + by_model_id = {model.model_id: model for model in conn.models} + for item in discovered: + db_model = by_model_id.get(item["model_id"]) + if db_model is None: + db_model = Model( + connection_id=conn.id, + model_id=item["model_id"], + display_name=item.get("display_name"), + source=item["source"], + capabilities=item["capabilities"], + capabilities_declared=item["capabilities"], + capabilities_verified={}, + capabilities_override={}, + enabled=False, + catalog=item.get("metadata") or {}, + ) + session.add(db_model) + else: + db_model.display_name = item.get("display_name") or db_model.display_name + db_model.capabilities_declared = item["capabilities"] + db_model.capabilities = { + **item["capabilities"], + **(db_model.capabilities_override or {}), + } + db_model.catalog = item.get("metadata") or db_model.catalog + await session.commit() + await session.refresh(conn) + return [_model_read(model) for model in conn.models] + + +@router.put("/models/{model_id}", response_model=ModelRead) +async def update_model( + model_id: int, + data: ModelUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + result = await session.execute( + select(Model).options(selectinload(Model.connection)).where(Model.id == model_id) + ) + model = result.scalars().first() + if not model: + raise HTTPException(status_code=404, detail="Model not found") + await _assert_connection_access(session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value) + update = data.model_dump(exclude_unset=True) + for key, value in update.items(): + setattr(model, key, value) + if "capabilities_override" in update: + model.capabilities = { + **(model.capabilities_declared or {}), + **(model.capabilities_override or {}), + } + await session.commit() + await session.refresh(model) + return _model_read(model) + + +@router.post("/models/{model_id}/test", response_model=VerifyConnectionResponse) +async def test_connection_model( + model_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + result = await session.execute( + select(Model).options(selectinload(Model.connection)).where(Model.id == model_id) + ) + model = result.scalars().first() + if not model: + raise HTTPException(status_code=404, detail="Model not found") + await _assert_connection_access(session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value) + result = await test_model(model.connection, model) + await session.commit() + return VerifyConnectionResponse(status=result.status, ok=result.ok, message=result.message) + + +@router.get("/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead) +async def get_model_roles( + search_space_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + await check_permission( + session, + user, + search_space_id, + Permission.LLM_CONFIGS_CREATE.value, + "You don't have permission to view model roles in this search space", + ) + search_space = await _get_search_space(session, search_space_id) + return ModelRolesRead( + chat_model_id=search_space.chat_model_id, + vision_model_id=search_space.vision_model_id, + image_gen_model_id=search_space.image_gen_model_id, + ) + + +@router.put("/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead) +async def update_model_roles( + search_space_id: int, + data: ModelRolesUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + await check_permission( + session, + user, + search_space_id, + Permission.LLM_CONFIGS_UPDATE.value, + "You don't have permission to update model roles in this search space", + ) + search_space = await _get_search_space(session, search_space_id) + for key, value in data.model_dump(exclude_unset=True).items(): + setattr(search_space, key, value) + await session.commit() + await session.refresh(search_space) + return ModelRolesRead( + chat_model_id=search_space.chat_model_id, + vision_model_id=search_space.vision_model_id, + image_gen_model_id=search_space.image_gen_model_id, + ) diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index fdf34672b..c14671c99 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -44,6 +44,16 @@ from .image_generation import ( ImageGenerationRead, ) from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate +from .model_connections import ( + ConnectionCreate, + ConnectionRead, + ConnectionUpdate, + ModelRead, + ModelRolesRead, + ModelRolesUpdate, + ModelUpdate, + VerifyConnectionResponse, +) from .new_chat import ( ChatMessage, NewChatMessageAppend, diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py new file mode 100644 index 000000000..731064375 --- /dev/null +++ b/surfsense_backend/app/schemas/model_connections.py @@ -0,0 +1,89 @@ +import uuid +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + +from app.db import ConnectionProtocol, ConnectionScope, ModelSource + + +class ModelRead(BaseModel): + id: int + connection_id: int + model_id: str + display_name: str | None = None + source: ModelSource | str + capabilities: dict[str, Any] + capabilities_declared: dict[str, Any] = Field(default_factory=dict) + capabilities_verified: dict[str, Any] = Field(default_factory=dict) + capabilities_override: dict[str, Any] = Field(default_factory=dict) + embedding_dimension: int | None = None + enabled: bool + billing_tier: str | None = None + catalog: dict[str, Any] = Field(default_factory=dict) + created_at: datetime | None = None + + model_config = ConfigDict(from_attributes=True) + + +class ConnectionRead(BaseModel): + id: int + protocol: ConnectionProtocol | str + native_provider: str | None = None + base_url: str | None = None + extra: dict[str, Any] = Field(default_factory=dict) + scope: ConnectionScope | str + search_space_id: int | None = None + user_id: uuid.UUID | None = None + enabled: bool + has_api_key: bool + last_verified_at: datetime | None = None + last_status: str | None = None + last_error: str | None = None + models: list[ModelRead] = Field(default_factory=list) + created_at: datetime | None = None + + model_config = ConfigDict(from_attributes=True) + + +class ConnectionCreate(BaseModel): + protocol: ConnectionProtocol + native_provider: str | None = None + base_url: str | None = Field(None, max_length=500) + api_key: str | None = None + extra: dict[str, Any] = Field(default_factory=dict) + scope: ConnectionScope = ConnectionScope.SEARCH_SPACE + search_space_id: int | None = None + enabled: bool = True + + +class ConnectionUpdate(BaseModel): + native_provider: str | None = None + base_url: str | None = Field(None, max_length=500) + api_key: str | None = None + extra: dict[str, Any] | None = None + enabled: bool | None = None + + +class ModelUpdate(BaseModel): + display_name: str | None = Field(None, max_length=255) + enabled: bool | None = None + capabilities_override: dict[str, Any] | None = None + + +class VerifyConnectionResponse(BaseModel): + status: str + ok: bool + message: str = "" + + +class ModelRolesRead(BaseModel): + chat_model_id: int | None = 0 + vision_model_id: int | None = 0 + image_gen_model_id: int | None = 0 + + +class ModelRolesUpdate(BaseModel): + chat_model_id: int | None = None + vision_model_id: int | None = None + image_gen_model_id: int | None = None diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py new file mode 100644 index 000000000..81090acaf --- /dev/null +++ b/surfsense_backend/app/services/model_connection_service.py @@ -0,0 +1,209 @@ +"""Connection verification, model discovery, and capability probing.""" + +from __future__ import annotations + +import contextlib +import logging +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Any + +import httpx +import litellm + +from app.db import Connection, ConnectionProtocol, Model, ModelSource +from app.services.model_resolver import ensure_v1, to_litellm + +logger = logging.getLogger(__name__) + +VERIFY_TIMEOUT_SECONDS = 8.0 +DISCOVERY_TIMEOUT_SECONDS = 15.0 +TEST_TIMEOUT_SECONDS = 30.0 + + +@dataclass(frozen=True) +class VerifyResult: + status: str + ok: bool + message: str = "" + + +def _auth_headers(conn: Connection) -> dict[str, str]: + if not conn.api_key: + return {} + return {"Authorization": f"Bearer {conn.api_key}"} + + +def _docker_hint(url: str | None, exc_or_status: Any) -> str: + raw = str(exc_or_status) + if not url: + return raw + if "localhost" in url or "127.0.0.1" in url: + return ( + f"{raw}. The backend is running inside Docker; localhost means the " + "backend container. Use host.docker.internal and make sure the model " + "server listens on 0.0.0.0." + ) + if "host.docker.internal" in url and ("refused" in raw.lower() or "connect" in raw.lower()): + return ( + f"{raw}. The host is reachable only if your local model server is " + "listening on 0.0.0.0. On Linux Docker, add " + "`host.docker.internal:host-gateway` to extra_hosts." + ) + return raw + + +async def verify_connection(conn: Connection) -> VerifyResult: + if not conn.base_url and conn.protocol in ( + ConnectionProtocol.OLLAMA, + ConnectionProtocol.OPENAI_COMPATIBLE, + ): + return VerifyResult("UNREACHABLE", False, "Base URL is required.") + + if conn.protocol == ConnectionProtocol.OLLAMA: + url = f"{conn.base_url.rstrip('/')}/api/version" + elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: + url = f"{ensure_v1(conn.base_url)}/models" + else: + # Native providers do not share one cheap health endpoint. The model + # probe exercises the real path and is the authoritative check. + return VerifyResult("OK", True, "Native provider configuration accepted.") + + try: + async with httpx.AsyncClient(timeout=VERIFY_TIMEOUT_SECONDS) as client: + response = await client.get(url, headers=_auth_headers(conn)) + if response.status_code in (401, 403): + return VerifyResult("AUTH_FAILED", False, "Authentication failed.") + if response.status_code == 404: + if conn.protocol == ConnectionProtocol.OLLAMA and url.endswith("/v1/models"): + message = "Ollama native API should not use /v1." + elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: + message = "OpenAI-compatible servers should expose /v1/models." + else: + message = "Endpoint returned 404." + return VerifyResult("NOT_FOUND", False, message) + response.raise_for_status() + return VerifyResult("OK", True, "Connection verified.") + except httpx.ConnectError as exc: + return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc)) + except httpx.TimeoutException as exc: + return VerifyResult("UNREACHABLE", False, f"Connection timed out: {exc}") + except httpx.HTTPError as exc: + return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc)) + + +async def persist_verification(conn: Connection) -> VerifyResult: + result = await verify_connection(conn) + conn.last_verified_at = datetime.now(UTC) + conn.last_status = result.status + conn.last_error = "" if result.ok else result.message + return result + + +def _litellm_capabilities(model_string: str, model_id: str) -> dict[str, bool]: + capabilities = { + "chat": True, + "vision": False, + "tools": False, + "image_gen": False, + "embedding": False, + } + with contextlib.suppress(Exception): + capabilities["vision"] = bool(litellm.supports_vision(model=model_string)) + with contextlib.suppress(Exception): + capabilities["tools"] = bool(litellm.supports_function_calling(model=model_string)) + try: + info = litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {} + mode = str(info.get("mode") or "") + capabilities["embedding"] = mode == "embedding" + capabilities["image_gen"] = mode in {"image_generation", "image_generation_model"} + except Exception: + pass + return capabilities + + +def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, bool]: + metadata = metadata or {} + if conn.protocol == ConnectionProtocol.OLLAMA: + caps = metadata.get("capabilities") or [] + capabilities = { + "chat": True, + "vision": "vision" in caps, + "tools": False, + "image_gen": False, + "embedding": "embedding" in caps, + } + return capabilities + + model_string, _ = to_litellm(conn, model_id) + return _litellm_capabilities(model_string, model_id) + + +async def discover_models(conn: Connection) -> list[dict[str, Any]]: + if conn.protocol == ConnectionProtocol.OLLAMA: + url = f"{conn.base_url.rstrip('/')}/api/tags" + async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: + response = await client.get(url, headers=_auth_headers(conn)) + response.raise_for_status() + models = response.json().get("models", []) + return [ + { + "model_id": item.get("model") or item.get("name"), + "display_name": item.get("name") or item.get("model"), + "source": ModelSource.DISCOVERED, + "capabilities": derive_capabilities(conn, item.get("model") or item.get("name"), item), + "metadata": item, + } + for item in models + if item.get("model") or item.get("name") + ] + + if conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: + url = f"{ensure_v1(conn.base_url)}/models" + async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: + response = await client.get(url, headers=_auth_headers(conn)) + response.raise_for_status() + models = response.json().get("data", []) + return [ + { + "model_id": item.get("id"), + "display_name": item.get("name") or item.get("id"), + "source": ModelSource.DISCOVERED, + "capabilities": derive_capabilities(conn, item.get("id"), item), + "metadata": item, + } + for item in models + if item.get("id") + ] + + # Native providers rely on curated/global catalog entries or manual rows. + return [] + + +async def test_model(conn: Connection, model: Model) -> VerifyResult: + model_string, kwargs = to_litellm(conn, model.model_id) + try: + await litellm.acompletion( + model=model_string, + messages=[{"role": "user", "content": "Hello"}], + timeout=TEST_TIMEOUT_SECONDS, + **kwargs, + ) + except Exception as exc: + return VerifyResult("UNREACHABLE", False, str(exc)) + + model.capabilities_verified = { + **(model.capabilities_verified or {}), + "chat": True, + } + return VerifyResult("OK", True, "Model test succeeded.") + + +__all__ = [ + "VerifyResult", + "derive_capabilities", + "discover_models", + "persist_verification", + "test_model", + "verify_connection", +] diff --git a/surfsense_backend/tests/unit/services/test_model_connections.py b/surfsense_backend/tests/unit/services/test_model_connections.py new file mode 100644 index 000000000..98042501b --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_model_connections.py @@ -0,0 +1,75 @@ +from app.services.global_model_catalog import materialize_global_model_catalog +from app.services.model_resolver import ensure_v1, to_litellm + + +def test_openai_compatible_resolver_normalizes_v1() -> None: + model, kwargs = to_litellm( + { + "protocol": "OPENAI_COMPATIBLE", + "base_url": "http://host.docker.internal:1234", + "api_key": "local-key", + "extra": {}, + }, + "qwen/qwen3", + ) + + assert model == "openai/qwen/qwen3" + assert kwargs["api_base"] == "http://host.docker.internal:1234/v1" + assert kwargs["api_key"] == "local-key" + assert ensure_v1("http://example.com/v1") == "http://example.com/v1" + + +def test_ollama_resolver_uses_native_api_base() -> None: + model, kwargs = to_litellm( + { + "protocol": "OLLAMA", + "base_url": "http://host.docker.internal:11434", + "api_key": None, + "extra": {}, + }, + "llama3.2", + ) + + assert model == "ollama_chat/llama3.2" + assert kwargs["api_base"] == "http://host.docker.internal:11434" + + +def test_global_materialization_preserves_tier_and_keeps_key_server_side() -> None: + connections, models = materialize_global_model_catalog( + chat_configs=[ + { + "id": -101, + "name": "OpenRouter Free", + "provider": "OPENROUTER", + "model_name": "meta-llama/llama-3.1-8b-instruct:free", + "api_key": "sk-global-secret", + "billing_tier": "free", + "anonymous_enabled": True, + "seo_enabled": True, + "rpm": 10, + "tpm": 1000, + }, + { + "id": -102, + "name": "OpenRouter Premium", + "provider": "OPENROUTER", + "model_name": "anthropic/claude-sonnet-4", + "api_key": "sk-global-secret", + "billing_tier": "premium", + }, + ], + vision_configs=[], + image_configs=[], + ) + + assert len(connections) == 1 + assert connections[0]["api_key"] == "sk-global-secret" + assert {model["billing_tier"] for model in models} == {"free", "premium"} + assert models[0]["catalog"]["anonymous_enabled"] is True + assert models[0]["catalog"]["rpm"] == 10 + + public_connections = [ + {key: value for key, value in connection.items() if key != "api_key"} + for connection in connections + ] + assert "sk-" not in repr(public_connections) From 8b59ca59c11e6105d6b2558b73657766bf726869 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:47:42 +0530 Subject: [PATCH 002/212] feat(models): add provider catalog and resolver --- surfsense_backend/app/config/__init__.py | 33 +++- .../app/config/global_llm_config.example.yaml | 45 +++--- .../app/schemas/new_llm_config.py | 2 +- .../app/services/global_model_catalog.py | 142 ++++++++++++++++ .../app/services/model_resolver.py | 152 ++++++++++++++++++ .../app/services/provider_capabilities.py | 37 +---- 6 files changed, 355 insertions(+), 56 deletions(-) create mode 100644 surfsense_backend/app/services/global_model_catalog.py create mode 100644 surfsense_backend/app/services/model_resolver.py diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 75af17d11..b8addb45d 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -389,10 +389,28 @@ def initialize_openrouter_integration(): ) except Exception as e: print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}") + + refresh_global_model_catalog() except Exception as e: print(f"Warning: Failed to initialize OpenRouter integration: {e}") +def materialize_global_configs(): + from app.services.global_model_catalog import materialize_global_model_catalog + + return materialize_global_model_catalog( + chat_configs=getattr(config, "GLOBAL_LLM_CONFIGS", []), + vision_configs=getattr(config, "GLOBAL_VISION_LLM_CONFIGS", []), + image_configs=getattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", []), + ) + + +def refresh_global_model_catalog(): + connections, models = materialize_global_configs() + config.GLOBAL_CONNECTIONS = connections + config.GLOBAL_MODELS = models + + def initialize_pricing_registration(): """ Teach LiteLLM the per-token cost of every deployment in @@ -723,7 +741,7 @@ class Config: os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000") ) - # Per-podcast reservation (in micro-USD). One agent LLM call generating + # Per-podcast reservation (in micro-USD). One chat model call generating # a transcript, typically 5k-20k completion tokens. $0.20 covers a long # premium-model run. Tune via env. QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int( @@ -849,6 +867,19 @@ class Config: # Router settings for Vision LLM Auto mode VISION_LLM_ROUTER_SETTINGS = load_vision_llm_router_settings() + # Virtual GLOBAL connection/model catalog. This is server-only metadata + # derived from global_llm_config.yaml; GLOBAL keys are not stored in DB. + from app.services.global_model_catalog import ( + materialize_global_model_catalog as _materialize_global_model_catalog, + ) + + GLOBAL_CONNECTIONS, GLOBAL_MODELS = _materialize_global_model_catalog( + chat_configs=GLOBAL_LLM_CONFIGS, + vision_configs=GLOBAL_VISION_LLM_CONFIGS, + image_configs=GLOBAL_IMAGE_GEN_CONFIGS, + ) + del _materialize_global_model_catalog + # OpenRouter Integration settings (optional) OPENROUTER_INTEGRATION_SETTINGS = load_openrouter_integration_settings() diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index 1c09a91ac..b0eee6458 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -7,8 +7,9 @@ # NOTE: The example API keys below are placeholders and won't work. # Replace them with your actual API keys to enable global configurations. # -# These configurations will be available to all users as a convenient option -# Users can choose to use these global configs or add their own +# These configurations are materialized as server-owned GLOBAL connections/models +# and become available on the Models page. Users can choose hosted/global models +# or add their own BYOK/local connections. # # AUTO MODE (Recommended): # - Auto mode (ID: 0) uses LiteLLM Router to automatically load balance across all global configs @@ -16,9 +17,12 @@ # - New users are automatically assigned Auto mode by default # - Configure router_settings below to customize the load balancing behavior # -# Structure matches NewLLMConfig: -# - Model configuration (provider, model_name, api_key, etc.) -# - Prompt configuration (system_instructions, citations_enabled) +# Static config shape: +# - Connection fields: provider, api_key, api_base, api_version +# - Model fields: model_name, billing_tier, rpm/tpm, litellm_params +# - Prompt defaults: system_instructions, citations_enabled +# IDs share one GLOBAL model namespace across chat, vision, and image generation. +# Suggested ranges: chat -1..-999, vision -1001..-1999, image -2001..-2999. # # COST-BASED PREMIUM CREDITS: # Each premium config bills the user's USD-credit balance based on the @@ -327,7 +331,7 @@ openrouter_integration: quota_reserve_tokens: 4000 # id_offset: base negative ID for dynamically generated configs. # Model IDs are derived deterministically via BLAKE2b so they survive - # catalogue churn. Must not overlap with your static global_llm_configs IDs. + # catalogue churn. Must not overlap with any static GLOBAL model IDs. id_offset: -10000 # refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only) refresh_interval_hours: 24 @@ -351,8 +355,8 @@ openrouter_integration: # Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue # contains hundreds of image- and vision-capable models; turning these on - # injects them into the global Image-Generation / Vision-LLM model - # selectors alongside any static configs. Tier (free/premium) is derived + # injects them into the global image-generation / vision model lists + # alongside any static configs. Tier (free/premium) is derived # per model the same way it is for chat (`:free` suffix or zero pricing). # When a user picks a premium image/vision model the call debits the # shared $5 USD-cost-based premium credit pool — so leaving these off @@ -384,7 +388,7 @@ image_generation_router_settings: global_image_generation_configs: # Example: OpenAI DALL-E 3 - - id: -1 + - id: -2001 name: "Global DALL-E 3" description: "OpenAI's DALL-E 3 for high-quality image generation" provider: "OPENAI" @@ -395,7 +399,7 @@ global_image_generation_configs: litellm_params: {} # Example: OpenAI GPT Image 1 - - id: -2 + - id: -2002 name: "Global GPT Image 1" description: "OpenAI's GPT Image 1 model" provider: "OPENAI" @@ -406,7 +410,7 @@ global_image_generation_configs: litellm_params: {} # Example: Azure OpenAI DALL-E 3 - - id: -3 + - id: -2003 name: "Global Azure DALL-E 3" description: "Azure-hosted DALL-E 3 deployment" provider: "AZURE_OPENAI" @@ -419,7 +423,7 @@ global_image_generation_configs: base_model: "dall-e-3" # Example: OpenRouter Gemini Image Generation - # - id: -4 + # - id: -2004 # name: "Global Gemini Image Gen" # description: "Google Gemini image generation via OpenRouter" # provider: "OPENROUTER" @@ -448,7 +452,7 @@ vision_llm_router_settings: global_vision_llm_configs: # Example: OpenAI GPT-4o (recommended for vision) - - id: -1 + - id: -1001 name: "Global GPT-4o Vision" description: "OpenAI's GPT-4o with strong vision capabilities" provider: "OPENAI" @@ -462,7 +466,7 @@ global_vision_llm_configs: max_tokens: 1000 # Example: Google Gemini 2.0 Flash - - id: -2 + - id: -1002 name: "Global Gemini 2.0 Flash" description: "Google's fast vision model with large context" provider: "GOOGLE" @@ -476,7 +480,7 @@ global_vision_llm_configs: max_tokens: 1000 # Example: Anthropic Claude 3.5 Sonnet - - id: -3 + - id: -1003 name: "Global Claude 3.5 Sonnet Vision" description: "Anthropic's Claude 3.5 Sonnet with vision support" provider: "ANTHROPIC" @@ -490,7 +494,7 @@ global_vision_llm_configs: max_tokens: 1000 # Example: Azure OpenAI GPT-4o - # - id: -4 + # - id: -1004 # name: "Global Azure GPT-4o Vision" # description: "Azure-hosted GPT-4o for vision analysis" # provider: "AZURE_OPENAI" @@ -507,8 +511,9 @@ global_vision_llm_configs: # Notes: # - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing -# - Use negative IDs to distinguish global configs from user configs (NewLLMConfig in DB) -# - IDs should be unique and sequential (e.g., -1, -2, -3, etc.) +# - Use negative IDs to distinguish global models from BYOK/local DB models +# - IDs must be unique across chat, vision, and image generation configs +# - Suggested static ranges: chat -1..-999, vision -1001..-1999, image -2001..-2999 # - The 'api_key' field will not be exposed to users via API # - system_instructions: Custom prompt or empty string to use defaults # - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty @@ -519,7 +524,7 @@ global_vision_llm_configs: # # # IMAGE GENERATION NOTES: -# - Image generation configs use the same ID scheme as LLM configs (negative for global) +# - Image generation configs use the shared GLOBAL ID namespace # - Supported models: dall-e-2, dall-e-3, gpt-image-1 (OpenAI), azure/* (Azure), # bedrock/* (AWS), vertex_ai/* (Google), recraft/* (Recraft), openrouter/* (OpenRouter) # - The router uses litellm.aimage_generation() for async image generation @@ -527,7 +532,7 @@ global_vision_llm_configs: # TPM (tokens per minute) does not apply since image APIs are billed/rate-limited per request, not per token. # # VISION LLM NOTES: -# - Vision configs use the same ID scheme (negative for global, positive for user DB) +# - Vision configs use the shared GLOBAL ID namespace # - Only use vision-capable models (GPT-4o, Gemini, Claude 3, etc.) # - Lower temperature (0.3) is recommended for accurate screenshot analysis # - Lower max_tokens (1000) is sufficient since autocomplete produces short suggestions diff --git a/surfsense_backend/app/schemas/new_llm_config.py b/surfsense_backend/app/schemas/new_llm_config.py index 716aa0457..2f04a9e66 100644 --- a/surfsense_backend/app/schemas/new_llm_config.py +++ b/surfsense_backend/app/schemas/new_llm_config.py @@ -229,7 +229,7 @@ class LLMPreferencesRead(BaseModel): description="ID of the vision LLM config to use for vision/screenshot analysis", ) agent_llm: dict[str, Any] | None = Field( - None, description="Full config for agent LLM" + None, description="Full config for chat model" ) image_generation_config: dict[str, Any] | None = Field( None, description="Full config for image generation" diff --git a/surfsense_backend/app/services/global_model_catalog.py b/surfsense_backend/app/services/global_model_catalog.py new file mode 100644 index 000000000..a43f58b9e --- /dev/null +++ b/surfsense_backend/app/services/global_model_catalog.py @@ -0,0 +1,142 @@ +"""Materialize server-owned GLOBAL YAML configs as virtual connections/models.""" + +from __future__ import annotations + +from typing import Any + +from app.services.model_resolver import native_connection_from_config + + +def _base_model(config: dict[str, Any]) -> str | None: + litellm_params = config.get("litellm_params") or {} + if isinstance(litellm_params, dict): + return litellm_params.get("base_model") + return None + + +def _connection_key(conn: dict[str, Any]) -> tuple[Any, ...]: + # Deliberately includes api_key because two operator-owned credentials for + # the same provider/base can have different quota/rate limits upstream. + return ( + conn.get("protocol"), + conn.get("native_provider"), + conn.get("base_url"), + conn.get("api_key"), + _freeze(conn.get("extra") or {}), + ) + + +def _freeze(value: Any) -> Any: + if isinstance(value, dict): + return tuple(sorted((key, _freeze(val)) for key, val in value.items())) + if isinstance(value, list): + return tuple(_freeze(item) for item in value) + return value + + +def _capabilities_for(role: str, config: dict[str, Any]) -> dict[str, bool]: + return { + "chat": role == "chat", + "vision": role == "vision" or bool(config.get("supports_image_input")), + "image_gen": role == "image_gen", + "embedding": False, + "tools": bool(config.get("supports_tools", False)), + } + + +def _catalog_metadata(config: dict[str, Any]) -> dict[str, Any]: + return { + "billing_tier": config.get("billing_tier", "free"), + "quota_reserve_tokens": config.get("quota_reserve_tokens"), + "rpm": config.get("rpm"), + "tpm": config.get("tpm"), + "anonymous_enabled": config.get("anonymous_enabled", False), + "seo_enabled": config.get("seo_enabled", False), + "seo_slug": config.get("seo_slug"), + "input_cost_per_token": (config.get("litellm_params") or {}).get( + "input_cost_per_token" + ) + if isinstance(config.get("litellm_params"), dict) + else None, + "output_cost_per_token": (config.get("litellm_params") or {}).get( + "output_cost_per_token" + ) + if isinstance(config.get("litellm_params"), dict) + else None, + "is_planner": config.get("is_planner", False), + "base_model": _base_model(config), + "router_pool_eligible": config.get("router_pool_eligible", True), + } + + +def materialize_global_model_catalog( + *, + chat_configs: list[dict[str, Any]], + vision_configs: list[dict[str, Any]], + image_configs: list[dict[str, Any]], +) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + connections: list[dict[str, Any]] = [] + models: list[dict[str, Any]] = [] + connection_id_by_key: dict[tuple[Any, ...], int] = {} + next_connection_id = -1 + + def add_config(config: dict[str, Any], role: str) -> None: + nonlocal next_connection_id + if not config.get("id") or not config.get("model_name"): + return + conn = native_connection_from_config(config) + conn["scope"] = "GLOBAL" + conn["enabled"] = True + conn["last_status"] = "OK" + key = _connection_key(conn) + connection_id = connection_id_by_key.get(key) + if connection_id is None: + connection_id = next_connection_id + next_connection_id -= 1 + connection_id_by_key[key] = connection_id + connections.append( + { + "id": connection_id, + **conn, + } + ) + + model_id = int(config["id"]) + models.append( + { + "id": model_id, + "connection_id": connection_id, + "model_id": config["model_name"], + "display_name": config.get("name") or config["model_name"], + "source": "MANUAL", + "capabilities": _capabilities_for(role, config), + "capabilities_declared": _capabilities_for(role, config), + "capabilities_verified": _capabilities_for(role, config), + "capabilities_override": {}, + "embedding_dimension": None, + "enabled": True, + "billing_tier": config.get("billing_tier", "free"), + "catalog": _catalog_metadata(config), + "role": role, + } + ) + + for cfg in chat_configs: + if cfg.get("is_auto_mode"): + continue + add_config(cfg, "chat") + for cfg in vision_configs: + if cfg.get("is_auto_mode"): + continue + add_config(cfg, "vision") + for cfg in image_configs: + if cfg.get("is_auto_mode"): + continue + add_config(cfg, "image_gen") + + # Each virtual connection is server-only. Callers that serialize these + # must strip api_key before returning data to clients. + return connections, models + + +__all__ = ["materialize_global_model_catalog"] diff --git a/surfsense_backend/app/services/model_resolver.py b/surfsense_backend/app/services/model_resolver.py new file mode 100644 index 000000000..ec485a5ae --- /dev/null +++ b/surfsense_backend/app/services/model_resolver.py @@ -0,0 +1,152 @@ +"""Single model-to-LiteLLM resolver. + +All chat, vision, image-generation, validation, and Auto routing paths should +turn a Connection + Model into LiteLLM input through this module. +""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any + +from app.services.provider_api_base import resolve_api_base + +if TYPE_CHECKING: + from app.db import Connection + +PROTOCOL_OLLAMA = "OLLAMA" +PROTOCOL_OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" +PROTOCOL_NATIVE = "NATIVE" + +NATIVE_PROVIDER_PREFIX: dict[str, str] = { + "OPENAI": "openai", + "ANTHROPIC": "anthropic", + "GROQ": "groq", + "COHERE": "cohere", + "GOOGLE": "gemini", + "MISTRAL": "mistral", + "AZURE_OPENAI": "azure", + "AZURE": "azure", + "OPENROUTER": "openrouter", + "COMETAPI": "cometapi", + "XAI": "xai", + "BEDROCK": "bedrock", + "AWS_BEDROCK": "bedrock", + "VERTEX_AI": "vertex_ai", + "TOGETHER_AI": "together_ai", + "FIREWORKS_AI": "fireworks_ai", + "DEEPSEEK": "openai", + "ALIBABA_QWEN": "openai", + "MOONSHOT": "openai", + "ZHIPU": "openai", + "GITHUB_MODELS": "github", + "REPLICATE": "replicate", + "PERPLEXITY": "perplexity", + "ANYSCALE": "anyscale", + "DEEPINFRA": "deepinfra", + "CEREBRAS": "cerebras", + "SAMBANOVA": "sambanova", + "AI21": "ai21", + "CLOUDFLARE": "cloudflare", + "DATABRICKS": "databricks", + "HUGGINGFACE": "huggingface", + "MINIMAX": "openai", + "RECRAFT": "recraft", + "XINFERENCE": "xinference", + "NSCALE": "nscale", + "CUSTOM": "custom", +} + + +def ensure_v1(base_url: str | None) -> str | None: + if not base_url: + return None + stripped = base_url.rstrip("/") + if stripped.endswith("/v1"): + return stripped + return f"{stripped}/v1" + + +def _conn_value(conn: Connection | Mapping[str, Any], key: str) -> Any: + if isinstance(conn, Mapping): + return conn.get(key) + return getattr(conn, key) + + +def _protocol_value(protocol: Any) -> str: + return getattr(protocol, "value", str(protocol)) + + +def to_litellm( + conn: Connection | Mapping[str, Any], + model_id: str, +) -> tuple[str, dict[str, Any]]: + """Return ``(model_string, litellm_kwargs)`` for any model role.""" + protocol = _protocol_value(_conn_value(conn, "protocol")) + base_url = _conn_value(conn, "base_url") + api_key = _conn_value(conn, "api_key") + native_provider = _conn_value(conn, "native_provider") + extra = _conn_value(conn, "extra") or {} + + kwargs: dict[str, Any] = {} + if api_key: + kwargs["api_key"] = api_key + + if protocol == PROTOCOL_OLLAMA: + model_string = f"ollama_chat/{model_id}" + if base_url: + kwargs["api_base"] = base_url.rstrip("/") + elif protocol == PROTOCOL_OPENAI_COMPATIBLE: + model_string = f"openai/{model_id}" + api_base = ensure_v1(base_url) + if api_base: + kwargs["api_base"] = api_base + else: + provider_key = (native_provider or "").upper() + prefix = NATIVE_PROVIDER_PREFIX.get(provider_key, provider_key.lower()) + if prefix == "custom": + custom_provider = extra.get("custom_provider") or native_provider + model_string = f"{custom_provider}/{model_id}" if custom_provider else model_id + else: + model_string = f"{prefix}/{model_id}" + + api_base = resolve_api_base( + provider=provider_key, + provider_prefix=prefix, + config_api_base=base_url, + ) + if api_base: + kwargs["api_base"] = api_base + + if api_version := extra.get("api_version"): + kwargs["api_version"] = api_version + kwargs.update(extra.get("litellm_params", {})) + kwargs.update(extra.get("kwargs", {})) + return model_string, kwargs + + +def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: + """Build an in-memory NATIVE connection mapping from a legacy/global config.""" + provider = str(config.get("provider") or config.get("custom_provider") or "CUSTOM") + extra: dict[str, Any] = { + "litellm_params": config.get("litellm_params") or {}, + } + if config.get("api_version"): + extra["api_version"] = config.get("api_version") + if config.get("custom_provider"): + extra["custom_provider"] = config.get("custom_provider") + return { + "protocol": PROTOCOL_NATIVE, + "native_provider": provider, + "base_url": config.get("api_base") or None, + "api_key": config.get("api_key") or None, + "extra": extra, + } + + +__all__ = [ + "NATIVE_PROVIDER_PREFIX", + "ensure_v1", + "native_connection_from_config", + "to_litellm", +] diff --git a/surfsense_backend/app/services/provider_capabilities.py b/surfsense_backend/app/services/provider_capabilities.py index f094c9954..9e1433214 100644 --- a/surfsense_backend/app/services/provider_capabilities.py +++ b/surfsense_backend/app/services/provider_capabilities.py @@ -46,6 +46,8 @@ from collections.abc import Iterable import litellm +from app.services.model_resolver import NATIVE_PROVIDER_PREFIX + logger = logging.getLogger(__name__) @@ -58,40 +60,7 @@ logger = logging.getLogger(__name__) # map there directly would re-introduce the # ``app.config -> ... -> deliverables/tools/generate_image -> # app.config`` cycle that prompted the move. -_PROVIDER_PREFIX_MAP: dict[str, str] = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "OLLAMA": "ollama_chat", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "OPENROUTER": "openrouter", - "XAI": "xai", - "BEDROCK": "bedrock", - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", - "GITHUB_MODELS": "github", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - "COMETAPI": "cometapi", - "HUGGINGFACE": "huggingface", - "MINIMAX": "openai", - "CUSTOM": "custom", -} +_PROVIDER_PREFIX_MAP = NATIVE_PROVIDER_PREFIX def _candidate_model_strings( From 62ff97c830f167c3810a7f50643f6cb779f98aad Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:48:23 +0530 Subject: [PATCH 003/212] refactor(llm): route calls through resolved models --- .../app/services/billable_calls.py | 4 +- .../app/services/image_gen_router_service.py | 53 +-- .../app/services/llm_router_service.py | 88 +--- surfsense_backend/app/services/llm_service.py | 434 +++++++----------- .../app/services/vision_llm_router_service.py | 53 +-- 5 files changed, 209 insertions(+), 423 deletions(-) diff --git a/surfsense_backend/app/services/billable_calls.py b/surfsense_backend/app/services/billable_calls.py index 92ccd6a78..356195f6a 100644 --- a/surfsense_backend/app/services/billable_calls.py +++ b/surfsense_backend/app/services/billable_calls.py @@ -450,10 +450,10 @@ async def _resolve_agent_billing_for_search_space( thread_id: int | None = None, ) -> tuple[UUID, str, str]: """Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space - agent LLM. + chat model. Used by Celery tasks (podcast generation, video presentation) to bill the - search-space owner's premium credit pool when the agent LLM is premium. + search-space owner's premium credit pool when the chat model is premium. Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``: diff --git a/surfsense_backend/app/services/image_gen_router_service.py b/surfsense_backend/app/services/image_gen_router_service.py index b4de2a0bf..0b03f5c6d 100644 --- a/surfsense_backend/app/services/image_gen_router_service.py +++ b/surfsense_backend/app/services/image_gen_router_service.py @@ -20,7 +20,11 @@ from typing import Any from litellm import Router from litellm.utils import ImageResponse -from app.services.provider_api_base import resolve_api_base +from app.services.model_resolver import ( + NATIVE_PROVIDER_PREFIX, + native_connection_from_config, + to_litellm, +) logger = logging.getLogger(__name__) @@ -30,17 +34,7 @@ IMAGE_GEN_AUTO_MODE_ID = 0 # Provider mapping for LiteLLM model string construction. # Only includes providers that support image generation. # See: https://docs.litellm.ai/docs/image_generation#supported-providers -IMAGE_GEN_PROVIDER_MAP = { - "OPENAI": "openai", - "AZURE_OPENAI": "azure", - "GOOGLE": "gemini", # Google AI Studio - "VERTEX_AI": "vertex_ai", - "BEDROCK": "bedrock", # AWS Bedrock - "RECRAFT": "recraft", - "OPENROUTER": "openrouter", - "XINFERENCE": "xinference", - "NSCALE": "nscale", -} +IMAGE_GEN_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX class ImageGenRouterService: @@ -153,38 +147,11 @@ class ImageGenRouterService: if not config.get("model_name") or not config.get("api_key"): return None - # Build model string - provider = config.get("provider", "").upper() - if config.get("custom_provider"): - provider_prefix = config["custom_provider"] - else: - provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower()) - model_string = f"{provider_prefix}/{config['model_name']}" - - # Build litellm params - litellm_params: dict[str, Any] = { - "model": model_string, - "api_key": config.get("api_key"), - } - - # Resolve ``api_base`` so deployments don't silently inherit - # ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against - # the wrong provider (see ``provider_api_base`` docstring). - api_base = resolve_api_base( - provider=provider, - provider_prefix=provider_prefix, - config_api_base=config.get("api_base"), + model_string, resolved_kwargs = to_litellm( + native_connection_from_config(config), + config["model_name"], ) - if api_base: - litellm_params["api_base"] = api_base - - # Add api_version (required for Azure) - if config.get("api_version"): - litellm_params["api_version"] = config["api_version"] - - # Add any additional litellm parameters - if config.get("litellm_params"): - litellm_params.update(config["litellm_params"]) + litellm_params: dict[str, Any] = {"model": model_string, **resolved_kwargs} # All configs use same alias "auto" for unified routing deployment: dict[str, Any] = { diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index d220aa346..69feb30eb 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -30,6 +30,11 @@ from litellm.exceptions import ( ) from pydantic import Field +from app.services.model_resolver import ( + NATIVE_PROVIDER_PREFIX, + native_connection_from_config, + to_litellm, +) from app.utils.perf import get_perf_logger litellm.json_logs = False @@ -96,52 +101,8 @@ def _sanitize_content(content: Any) -> Any: # Special ID for Auto mode - uses router for load balancing AUTO_MODE_ID = 0 -# Provider mapping for LiteLLM model string construction -PROVIDER_MAP = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "OLLAMA": "ollama_chat", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "OPENROUTER": "openrouter", - "COMETAPI": "cometapi", - "XAI": "xai", - "BEDROCK": "bedrock", - "AWS_BEDROCK": "bedrock", # Legacy support - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", - "GITHUB_MODELS": "github", - "HUGGINGFACE": "huggingface", - "MINIMAX": "openai", - "CUSTOM": "custom", -} - - -# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were -# hoisted to ``app.services.provider_api_base`` so vision and image-gen -# call sites can share the exact same defense (OpenRouter / Groq / etc. -# 404-ing against an inherited Azure endpoint). Re-exported here for -# backward compatibility with any external import. -from app.services.provider_api_base import ( # noqa: E402 - resolve_api_base, -) +# Historical export kept for callers that still import PROVIDER_MAP. +PROVIDER_MAP = NATIVE_PROVIDER_PREFIX class LLMRouterService: @@ -420,38 +381,11 @@ class LLMRouterService: if not config.get("model_name") or not config.get("api_key"): return None - # Build model string - provider = config.get("provider", "").upper() - if config.get("custom_provider"): - provider_prefix = config["custom_provider"] - model_string = f"{provider_prefix}/{config['model_name']}" - else: - provider_prefix = PROVIDER_MAP.get(provider, provider.lower()) - model_string = f"{provider_prefix}/{config['model_name']}" - - # Build litellm params - litellm_params = { - "model": model_string, - "api_key": config.get("api_key"), - } - - # Resolve ``api_base``. Config value wins; otherwise apply a - # provider-aware default so the deployment does not silently - # inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route - # requests to the wrong endpoint. See ``provider_api_base`` - # docstring for the motivating bug (OpenRouter models 404-ing - # against an Azure endpoint). - api_base = resolve_api_base( - provider=provider, - provider_prefix=provider_prefix, - config_api_base=config.get("api_base"), + model_string, resolved_kwargs = to_litellm( + native_connection_from_config(config), + config["model_name"], ) - if api_base: - litellm_params["api_base"] = api_base - - # Add any additional litellm parameters - if config.get("litellm_params"): - litellm_params.update(config["litellm_params"]) + litellm_params = {"model": model_string, **resolved_kwargs} # Extract rate limits if provided deployment = { diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 7061a826f..75451d01f 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -6,9 +6,10 @@ from langchain_core.messages import HumanMessage from langchain_litellm import ChatLiteLLM from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from sqlalchemy.orm import selectinload from app.config import config -from app.db import NewLLMConfig, SearchSpace +from app.db import Model, SearchSpace from app.services.llm_router_service import ( AUTO_MODE_ID, ChatLiteLLMRouter, @@ -16,7 +17,7 @@ from app.services.llm_router_service import ( get_auto_mode_llm, is_auto_mode, ) -from app.services.provider_api_base import resolve_api_base +from app.services.model_resolver import native_connection_from_config, to_litellm from app.services.token_tracking_service import token_tracker # Configure litellm to automatically drop unsupported parameters @@ -66,6 +67,29 @@ def _is_interactive_auth_provider( return False +def _legacy_config_connection( + *, + provider: str, + model_name: str, + api_key: str | None, + api_base: str | None, + custom_provider: str | None = None, + litellm_params: dict | None = None, + api_version: str | None = None, +) -> tuple[str, dict]: + cfg = { + "provider": provider, + "model_name": model_name, + "api_key": api_key, + "api_base": api_base, + "custom_provider": custom_provider, + "api_version": api_version, + "litellm_params": litellm_params or {}, + } + conn = native_connection_from_config(cfg) + return to_litellm(conn, model_name) + + class LLMRole: AGENT = "agent" # For agent/chat operations @@ -102,6 +126,60 @@ def get_global_llm_config(llm_config_id: int) -> dict | None: return None +def get_global_model(model_id: int) -> dict | None: + return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None) + + +def get_global_connection(connection_id: int) -> dict | None: + return next( + (c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id), + None, + ) + + +def _has_capability(model: dict | Model, capability: str) -> bool: + caps = ( + model.get("capabilities", {}) + if isinstance(model, dict) + else model.capabilities or {} + ) + return bool(caps.get(capability)) + + +def _chat_litellm_from_resolved( + *, + conn: dict | object, + model_id: str, + disable_streaming: bool = False, +) -> tuple[str, dict]: + model_string, resolved_kwargs = to_litellm(conn, model_id) + litellm_kwargs = {"model": model_string, **resolved_kwargs} + if disable_streaming: + litellm_kwargs["disable_streaming"] = True + return model_string, litellm_kwargs + + +async def _get_db_model( + session: AsyncSession, + model_id: int, + search_space: SearchSpace, +) -> Model | None: + result = await session.execute( + select(Model) + .options(selectinload(Model.connection)) + .where(Model.id == model_id, Model.enabled.is_(True)) + ) + model = result.scalars().first() + if not model or not model.connection or not model.connection.enabled: + return None + conn = model.connection + if conn.search_space_id and conn.search_space_id != search_space.id: + return None + if conn.user_id and conn.user_id != search_space.user_id: + return None + return model + + async def validate_llm_config( provider: str, model_name: str, @@ -146,62 +224,15 @@ async def validate_llm_config( return False, msg try: - # Build the model string for litellm - if custom_provider: - model_string = f"{custom_provider}/{model_name}" - else: - # Map provider enum to litellm format - provider_map = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "OLLAMA": "ollama_chat", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "OPENROUTER": "openrouter", - "COMETAPI": "cometapi", - "XAI": "xai", - "BEDROCK": "bedrock", - "AWS_BEDROCK": "bedrock", # Legacy support (backward compatibility) - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - # Chinese LLM providers - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", # GLM needs special handling - "MINIMAX": "openai", - "GITHUB_MODELS": "github", - } - provider_prefix = provider_map.get(provider, provider.lower()) - model_string = f"{provider_prefix}/{model_name}" - - # Create ChatLiteLLM instance - litellm_kwargs = { - "model": model_string, - "api_key": api_key, - "timeout": 30, # Set a timeout for validation - } - - # Add optional parameters - if api_base: - litellm_kwargs["api_base"] = api_base - - # Add any additional litellm parameters - if litellm_params: - litellm_kwargs.update(litellm_params) + model_string, resolved_kwargs = _legacy_config_connection( + provider=provider, + model_name=model_name, + api_key=api_key, + api_base=api_base, + custom_provider=custom_provider, + litellm_params=litellm_params, + ) + litellm_kwargs = {"model": model_string, **resolved_kwargs, "timeout": 30} from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -283,9 +314,9 @@ async def get_search_space_llm_instance( logger.error(f"Search space {search_space_id} not found") return None - # Get the appropriate LLM config ID based on role + # Get the appropriate model binding ID based on role if role == LLMRole.AGENT: - llm_config_id = search_space.agent_llm_id + llm_config_id = search_space.chat_model_id else: logger.error(f"Invalid LLM role: {role}") return None @@ -312,70 +343,26 @@ async def get_search_space_llm_instance( logger.error(f"Failed to create ChatLiteLLMRouter: {e}") return None - # Check if this is a global config (negative ID) + # Check if this is a global virtual model (negative ID) if llm_config_id < 0: - global_config = get_global_llm_config(llm_config_id) - if not global_config: - logger.error(f"Global LLM config {llm_config_id} not found") + global_model = get_global_model(llm_config_id) + if not global_model or not _has_capability(global_model, "chat"): + logger.error(f"Global chat model {llm_config_id} not found") + return None + global_connection = get_global_connection(global_model["connection_id"]) + if not global_connection: + logger.error( + "Global connection %s not found for model %s", + global_model["connection_id"], + llm_config_id, + ) return None - # Build model string for global config - if global_config.get("custom_provider"): - model_string = ( - f"{global_config['custom_provider']}/{global_config['model_name']}" - ) - else: - provider_map = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "OLLAMA": "ollama_chat", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "OPENROUTER": "openrouter", - "COMETAPI": "cometapi", - "XAI": "xai", - "BEDROCK": "bedrock", - "AWS_BEDROCK": "bedrock", - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", - "MINIMAX": "openai", - } - provider_prefix = provider_map.get( - global_config["provider"], global_config["provider"].lower() - ) - model_string = f"{provider_prefix}/{global_config['model_name']}" - - # Create ChatLiteLLM instance from global config - litellm_kwargs = { - "model": model_string, - "api_key": global_config["api_key"], - } - - if global_config.get("api_base"): - litellm_kwargs["api_base"] = global_config["api_base"] - - if global_config.get("litellm_params"): - litellm_kwargs.update(global_config["litellm_params"]) - - if disable_streaming: - litellm_kwargs["disable_streaming"] = True + _, litellm_kwargs = _chat_litellm_from_resolved( + conn=global_connection, + model_id=global_model["model_id"], + disable_streaming=disable_streaming, + ) from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -383,80 +370,18 @@ async def get_search_space_llm_instance( return SanitizedChatLiteLLM(**litellm_kwargs) - # Get the LLM configuration from database (NewLLMConfig) - result = await session.execute( - select(NewLLMConfig).where( - NewLLMConfig.id == llm_config_id, - NewLLMConfig.search_space_id == search_space_id, - ) - ) - llm_config = result.scalars().first() - - if not llm_config: + model = await _get_db_model(session, llm_config_id, search_space) + if not model or not _has_capability(model, "chat"): logger.error( - f"LLM config {llm_config_id} not found in search space {search_space_id}" + f"Chat model {llm_config_id} not found in search space {search_space_id}" ) return None - # Build the model string for litellm - if llm_config.custom_provider: - model_string = f"{llm_config.custom_provider}/{llm_config.model_name}" - else: - # Map provider enum to litellm format - provider_map = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "OLLAMA": "ollama_chat", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "OPENROUTER": "openrouter", - "COMETAPI": "cometapi", - "XAI": "xai", - "BEDROCK": "bedrock", - "AWS_BEDROCK": "bedrock", - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", - "MINIMAX": "openai", - "GITHUB_MODELS": "github", - } - provider_prefix = provider_map.get( - llm_config.provider.value, llm_config.provider.value.lower() - ) - model_string = f"{provider_prefix}/{llm_config.model_name}" - - # Create ChatLiteLLM instance - litellm_kwargs = { - "model": model_string, - "api_key": llm_config.api_key, - } - - # Add optional parameters - if llm_config.api_base: - litellm_kwargs["api_base"] = llm_config.api_base - - # Add any additional litellm parameters - if llm_config.litellm_params: - litellm_kwargs.update(llm_config.litellm_params) - - if disable_streaming: - litellm_kwargs["disable_streaming"] = True + _, litellm_kwargs = _chat_litellm_from_resolved( + conn=model.connection, + model_id=model.model_id, + disable_streaming=disable_streaming, + ) from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -474,7 +399,7 @@ async def get_search_space_llm_instance( async def get_agent_llm( session: AsyncSession, search_space_id: int, disable_streaming: bool = False ) -> ChatLiteLLM | ChatLiteLLMRouter | None: - """Get the search space's agent LLM instance for chat operations.""" + """Get the search space's chat model instance.""" return await get_search_space_llm_instance( session, search_space_id, @@ -488,22 +413,19 @@ async def get_vision_llm( ) -> ChatLiteLLM | ChatLiteLLMRouter | None: """Get the search space's vision LLM instance for screenshot analysis. - Resolves from the dedicated VisionLLMConfig system: + Resolves from the new connection/model role bindings: - Auto mode (ID 0): VisionLLMRouterService - - Global (negative ID): YAML configs - - DB (positive ID): VisionLLMConfig table + - Global (negative ID): virtual GLOBAL models from YAML + - DB (positive ID): Model + Connection tables Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM` so each ``ainvoke`` debits the search-space owner's premium credit pool. User-owned BYOK configs and free global configs are returned unwrapped — they don't consume premium credit (issue M). """ - from app.db import VisionLLMConfig from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM from app.services.vision_llm_router_service import ( - VISION_PROVIDER_MAP, VisionLLMRouterService, - get_global_vision_llm_config, is_vision_auto_mode, ) @@ -516,13 +438,43 @@ async def get_vision_llm( logger.error(f"Search space {search_space_id} not found") return None - config_id = search_space.vision_llm_config_id + owner_user_id = search_space.user_id + + # Prefer the selected chat model when it is vision-capable. + chat_model_id = search_space.chat_model_id + if chat_model_id and chat_model_id != AUTO_MODE_ID: + if chat_model_id < 0: + chat_model = get_global_model(chat_model_id) + if chat_model and _has_capability(chat_model, "vision"): + global_connection = get_global_connection(chat_model["connection_id"]) + if global_connection: + model_string, litellm_kwargs = _chat_litellm_from_resolved( + conn=global_connection, + model_id=chat_model["model_id"], + ) + from app.agents.chat.runtime.llm_config import ( + SanitizedChatLiteLLM, + ) + + return SanitizedChatLiteLLM(**litellm_kwargs) + else: + chat_model = await _get_db_model(session, chat_model_id, search_space) + if chat_model and _has_capability(chat_model, "vision"): + _, litellm_kwargs = _chat_litellm_from_resolved( + conn=chat_model.connection, + model_id=chat_model.model_id, + ) + from app.agents.chat.runtime.llm_config import ( + SanitizedChatLiteLLM, + ) + + return SanitizedChatLiteLLM(**litellm_kwargs) + + config_id = search_space.vision_model_id if config_id is None: logger.error(f"No vision LLM configured for search space {search_space_id}") return None - owner_user_id = search_space.user_id - if is_vision_auto_mode(config_id): if not VisionLLMRouterService.is_initialized(): logger.error( @@ -546,34 +498,24 @@ async def get_vision_llm( return None if config_id < 0: - global_cfg = get_global_vision_llm_config(config_id) - if not global_cfg: - logger.error(f"Global vision LLM config {config_id} not found") + global_model = get_global_model(config_id) + if not global_model or not _has_capability(global_model, "vision"): + logger.error(f"Global vision model {config_id} not found") return None - if global_cfg.get("custom_provider"): - provider_prefix = global_cfg["custom_provider"] - model_string = f"{provider_prefix}/{global_cfg['model_name']}" - else: - provider_prefix = VISION_PROVIDER_MAP.get( - global_cfg["provider"].upper(), - global_cfg["provider"].lower(), + global_connection = get_global_connection(global_model["connection_id"]) + if not global_connection: + logger.error( + "Global connection %s not found for model %s", + global_model["connection_id"], + config_id, ) - model_string = f"{provider_prefix}/{global_cfg['model_name']}" + return None - litellm_kwargs = { - "model": model_string, - "api_key": global_cfg["api_key"], - } - api_base = resolve_api_base( - provider=global_cfg.get("provider"), - provider_prefix=provider_prefix, - config_api_base=global_cfg.get("api_base"), + model_string, litellm_kwargs = _chat_litellm_from_resolved( + conn=global_connection, + model_id=global_model["model_id"], ) - if api_base: - litellm_kwargs["api_base"] = api_base - if global_cfg.get("litellm_params"): - litellm_kwargs.update(global_cfg["litellm_params"]) from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -581,7 +523,7 @@ async def get_vision_llm( inner_llm = SanitizedChatLiteLLM(**litellm_kwargs) - billing_tier = str(global_cfg.get("billing_tier", "free")).lower() + billing_tier = str(global_model.get("billing_tier", "free")).lower() if billing_tier == "premium": return QuotaCheckedVisionLLM( inner_llm, @@ -589,47 +531,23 @@ async def get_vision_llm( search_space_id=search_space_id, billing_tier=billing_tier, base_model=model_string, - quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"), + quota_reserve_tokens=global_model.get("catalog", {}).get( + "quota_reserve_tokens" + ), ) return inner_llm - # User-owned (positive ID) BYOK configs — always free. - result = await session.execute( - select(VisionLLMConfig).where( - VisionLLMConfig.id == config_id, - VisionLLMConfig.search_space_id == search_space_id, - ) - ) - vision_cfg = result.scalars().first() - if not vision_cfg: + model = await _get_db_model(session, config_id, search_space) + if not model or not _has_capability(model, "vision"): logger.error( - f"Vision LLM config {config_id} not found in search space {search_space_id}" + f"Vision model {config_id} not found in search space {search_space_id}" ) return None - if vision_cfg.custom_provider: - provider_prefix = vision_cfg.custom_provider - model_string = f"{provider_prefix}/{vision_cfg.model_name}" - else: - provider_prefix = VISION_PROVIDER_MAP.get( - vision_cfg.provider.value.upper(), - vision_cfg.provider.value.lower(), - ) - model_string = f"{provider_prefix}/{vision_cfg.model_name}" - - litellm_kwargs = { - "model": model_string, - "api_key": vision_cfg.api_key, - } - api_base = resolve_api_base( - provider=vision_cfg.provider.value, - provider_prefix=provider_prefix, - config_api_base=vision_cfg.api_base, + _, litellm_kwargs = _chat_litellm_from_resolved( + conn=model.connection, + model_id=model.model_id, ) - if api_base: - litellm_kwargs["api_base"] = api_base - if vision_cfg.litellm_params: - litellm_kwargs.update(vision_cfg.litellm_params) from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, diff --git a/surfsense_backend/app/services/vision_llm_router_service.py b/surfsense_backend/app/services/vision_llm_router_service.py index ed5de921c..0c7182ecf 100644 --- a/surfsense_backend/app/services/vision_llm_router_service.py +++ b/surfsense_backend/app/services/vision_llm_router_service.py @@ -3,29 +3,17 @@ from typing import Any from litellm import Router -from app.services.provider_api_base import resolve_api_base +from app.services.model_resolver import ( + NATIVE_PROVIDER_PREFIX, + native_connection_from_config, + to_litellm, +) logger = logging.getLogger(__name__) VISION_AUTO_MODE_ID = 0 -VISION_PROVIDER_MAP = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GOOGLE": "gemini", - "AZURE_OPENAI": "azure", - "VERTEX_AI": "vertex_ai", - "BEDROCK": "bedrock", - "XAI": "xai", - "OPENROUTER": "openrouter", - "OLLAMA": "ollama_chat", - "GROQ": "groq", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "DEEPSEEK": "openai", - "MISTRAL": "mistral", - "CUSTOM": "custom", -} +VISION_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX class VisionLLMRouterService: @@ -110,32 +98,11 @@ class VisionLLMRouterService: if not config.get("model_name") or not config.get("api_key"): return None - provider = config.get("provider", "").upper() - if config.get("custom_provider"): - provider_prefix = config["custom_provider"] - model_string = f"{provider_prefix}/{config['model_name']}" - else: - provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower()) - model_string = f"{provider_prefix}/{config['model_name']}" - - litellm_params: dict[str, Any] = { - "model": model_string, - "api_key": config.get("api_key"), - } - - api_base = resolve_api_base( - provider=provider, - provider_prefix=provider_prefix, - config_api_base=config.get("api_base"), + model_string, resolved_kwargs = to_litellm( + native_connection_from_config(config), + config["model_name"], ) - if api_base: - litellm_params["api_base"] = api_base - - if config.get("api_version"): - litellm_params["api_version"] = config["api_version"] - - if config.get("litellm_params"): - litellm_params.update(config["litellm_params"]) + litellm_params: dict[str, Any] = {"model": model_string, **resolved_kwargs} deployment: dict[str, Any] = { "model_name": "auto", From 077016d6e444c2bedb6d94cb4be6566771556815 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:48:37 +0530 Subject: [PATCH 004/212] refactor(images): use model connections for image generation --- .../deliverables/tools/generate_image.py | 104 +++++------------- .../builtins/deliverables/tools/index.py | 4 +- .../app/routes/image_generation_routes.py | 96 ++++------------ 3 files changed, 52 insertions(+), 152 deletions(-) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py index 7bb4a7c24..dd980c51c 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py @@ -10,13 +10,14 @@ from langgraph.types import Command from litellm import aimage_generation from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt from app.config import config from app.db import ( ImageGeneration, - ImageGenerationConfig, + Model, SearchSpace, shielded_async_session, ) @@ -25,37 +26,11 @@ from app.services.image_gen_router_service import ( ImageGenRouterService, is_image_gen_auto_mode, ) -from app.services.provider_api_base import resolve_api_base +from app.services.model_resolver import native_connection_from_config, to_litellm from app.utils.signed_image_urls import generate_image_token logger = logging.getLogger(__name__) -# Provider mapping (same as routes) -_PROVIDER_MAP = { - "OPENAI": "openai", - "AZURE_OPENAI": "azure", - "GOOGLE": "gemini", - "VERTEX_AI": "vertex_ai", - "BEDROCK": "bedrock", - "RECRAFT": "recraft", - "OPENROUTER": "openrouter", - "XINFERENCE": "xinference", - "NSCALE": "nscale", -} - - -def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: - if custom_provider: - return custom_provider - return _PROVIDER_MAP.get(provider.upper(), provider.lower()) - - -def _build_model_string( - provider: str, model_name: str, custom_provider: str | None -) -> str: - return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}" - - def _get_global_image_gen_config(config_id: int) -> dict | None: """Get a global image gen config by negative ID.""" for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: @@ -67,13 +42,13 @@ def _get_global_image_gen_config(config_id: int) -> dict | None: def create_generate_image_tool( search_space_id: int, db_session: AsyncSession, - image_generation_config_id_override: int | None = None, + image_gen_model_id_override: int | None = None, ): """Create ``generate_image`` with bound search space; DB work uses a per-call session. - ``image_generation_config_id_override``: when set (automations running on a - captured model), use this config id instead of reading the search space's - live ``image_generation_config_id``. + ``image_gen_model_id_override``: when set (automations running on a + captured model), use this model id instead of reading the search space's + live ``image_gen_model_id``. """ del db_session # tool uses a fresh per-call session instead @@ -118,11 +93,11 @@ def create_generate_image_tool( # task's session is shared across every tool; without isolation, # autoflushes from a concurrent writer poison this tool too. async with shielded_async_session() as session: - if image_generation_config_id_override is not None: + if image_gen_model_id_override is not None: # Automation run: use the captured image model, insulated from # later search-space changes. No search-space read needed. config_id = ( - image_generation_config_id_override or IMAGE_GEN_AUTO_MODE_ID + image_gen_model_id_override or IMAGE_GEN_AUTO_MODE_ID ) else: result = await session.execute( @@ -136,7 +111,7 @@ def create_generate_image_tool( ) config_id = ( - search_space.image_generation_config_id + search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID ) @@ -162,58 +137,35 @@ def create_generate_image_tool( err = f"Image generation config {config_id} not found" return _failed({"error": err}, error=err) - provider_prefix = _resolve_provider_prefix( - cfg.get("provider", ""), cfg.get("custom_provider") + model_string, resolved_kwargs = to_litellm( + native_connection_from_config(cfg), + cfg["model_name"], ) - model_string = f"{provider_prefix}/{cfg['model_name']}" - gen_kwargs["api_key"] = cfg.get("api_key") - # Defense-in-depth: an empty ``api_base`` must not fall - # through to LiteLLM's global ``api_base`` (e.g. Azure). - api_base = resolve_api_base( - provider=cfg.get("provider"), - provider_prefix=provider_prefix, - config_api_base=cfg.get("api_base"), - ) - if api_base: - gen_kwargs["api_base"] = api_base - if cfg.get("api_version"): - gen_kwargs["api_version"] = cfg["api_version"] - if cfg.get("litellm_params"): - gen_kwargs.update(cfg["litellm_params"]) + gen_kwargs.update(resolved_kwargs) response = await aimage_generation( prompt=prompt, model=model_string, **gen_kwargs ) else: - # Positive ID = user-created ImageGenerationConfig + # Positive ID = Model + Connection cfg_result = await session.execute( - select(ImageGenerationConfig).filter( - ImageGenerationConfig.id == config_id - ) + select(Model) + .options(selectinload(Model.connection)) + .filter(Model.id == config_id, Model.enabled.is_(True)) ) - db_cfg = cfg_result.scalars().first() - if not db_cfg: - err = f"Image generation config {config_id} not found" + db_model = cfg_result.scalars().first() + if not db_model or not db_model.connection or not db_model.connection.enabled: + err = f"Image generation model {config_id} not found" + return _failed({"error": err}, error=err) + if not (db_model.capabilities or {}).get("image_gen"): + err = f"Model {config_id} is not image-generation capable" return _failed({"error": err}, error=err) - provider_prefix = _resolve_provider_prefix( - db_cfg.provider.value, db_cfg.custom_provider + model_string, resolved_kwargs = to_litellm( + db_model.connection, + db_model.model_id, ) - model_string = f"{provider_prefix}/{db_cfg.model_name}" - gen_kwargs["api_key"] = db_cfg.api_key - # Defense-in-depth: an empty ``api_base`` must not fall - # through to LiteLLM's global ``api_base`` (e.g. Azure). - api_base = resolve_api_base( - provider=db_cfg.provider.value, - provider_prefix=provider_prefix, - config_api_base=db_cfg.api_base, - ) - if api_base: - gen_kwargs["api_base"] = api_base - if db_cfg.api_version: - gen_kwargs["api_version"] = db_cfg.api_version - if db_cfg.litellm_params: - gen_kwargs.update(db_cfg.litellm_params) + gen_kwargs.update(resolved_kwargs) response = await aimage_generation( prompt=prompt, model=model_string, **gen_kwargs diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py index b968c1701..8de95f2df 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py @@ -51,8 +51,6 @@ def load_tools( create_generate_image_tool( search_space_id=d["search_space_id"], db_session=d["db_session"], - image_generation_config_id_override=d.get( - "image_generation_config_id_override" - ), + image_gen_model_id_override=d.get("image_gen_model_id_override"), ), ] diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 018234ad5..0de368d57 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -16,11 +16,13 @@ from litellm import aimage_generation from sqlalchemy import select from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from app.config import config from app.db import ( ImageGeneration, ImageGenerationConfig, + Model, Permission, SearchSpace, SearchSpaceMembership, @@ -46,7 +48,7 @@ from app.services.image_gen_router_service import ( ImageGenRouterService, is_image_gen_auto_mode, ) -from app.services.provider_api_base import resolve_api_base +from app.services.model_resolver import native_connection_from_config, to_litellm from app.users import current_active_user from app.utils.rbac import check_permission from app.utils.signed_image_urls import verify_image_token @@ -54,22 +56,6 @@ from app.utils.signed_image_urls import verify_image_token router = APIRouter() logger = logging.getLogger(__name__) -# Provider mapping for building litellm model strings. -# Only includes providers that support image generation. -# See: https://docs.litellm.ai/docs/image_generation#supported-providers -_PROVIDER_MAP = { - "OPENAI": "openai", - "AZURE_OPENAI": "azure", - "GOOGLE": "gemini", # Google AI Studio - "VERTEX_AI": "vertex_ai", - "BEDROCK": "bedrock", # AWS Bedrock - "RECRAFT": "recraft", - "OPENROUTER": "openrouter", - "XINFERENCE": "xinference", - "NSCALE": "nscale", -} - - def _get_global_image_gen_config(config_id: int) -> dict | None: """Get a global image generation configuration by ID (negative IDs).""" if config_id == IMAGE_GEN_AUTO_MODE_ID: @@ -88,20 +74,6 @@ def _get_global_image_gen_config(config_id: int) -> dict | None: return None -def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: - """Resolve the LiteLLM provider prefix used in model strings.""" - if custom_provider: - return custom_provider - return _PROVIDER_MAP.get(provider.upper(), provider.lower()) - - -def _build_model_string( - provider: str, model_name: str, custom_provider: str | None -) -> str: - """Build a litellm model string from provider + model_name.""" - return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}" - - async def _resolve_billing_for_image_gen( session: AsyncSession, config_id: int | None, @@ -124,7 +96,7 @@ async def _resolve_billing_for_image_gen( """ resolved_id = config_id if resolved_id is None: - resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID + resolved_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID if is_image_gen_auto_mode(resolved_id): return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS) @@ -132,11 +104,7 @@ async def _resolve_billing_for_image_gen( if resolved_id < 0: cfg = _get_global_image_gen_config(resolved_id) or {} billing_tier = str(cfg.get("billing_tier", "free")).lower() - base_model = _build_model_string( - cfg.get("provider", ""), - cfg.get("model_name", ""), - cfg.get("custom_provider"), - ) + base_model, _ = to_litellm(native_connection_from_config(cfg), cfg.get("model_name", "")) reserve_micros = int( cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS ) @@ -161,7 +129,7 @@ async def _execute_image_generation( """ config_id = image_gen.image_generation_config_id if config_id is None: - config_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID + config_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID image_gen.image_generation_config_id = config_id # Build kwargs @@ -192,22 +160,11 @@ async def _execute_image_generation( if not cfg: raise ValueError(f"Global image generation config {config_id} not found") - provider_prefix = _resolve_provider_prefix( - cfg.get("provider", ""), cfg.get("custom_provider") + model_string, resolved_kwargs = to_litellm( + native_connection_from_config(cfg), + cfg["model_name"], ) - model_string = f"{provider_prefix}/{cfg['model_name']}" - gen_kwargs["api_key"] = cfg.get("api_key") - api_base = resolve_api_base( - provider=cfg.get("provider"), - provider_prefix=provider_prefix, - config_api_base=cfg.get("api_base"), - ) - if api_base: - gen_kwargs["api_base"] = api_base - if cfg.get("api_version"): - gen_kwargs["api_version"] = cfg["api_version"] - if cfg.get("litellm_params"): - gen_kwargs.update(cfg["litellm_params"]) + gen_kwargs.update(resolved_kwargs) # User model override if image_gen.model: @@ -217,30 +174,23 @@ async def _execute_image_generation( prompt=image_gen.prompt, model=model_string, **gen_kwargs ) else: - # Positive ID = DB ImageGenerationConfig + # Positive ID = Model + Connection result = await session.execute( - select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) + select(Model) + .options(selectinload(Model.connection)) + .filter(Model.id == config_id, Model.enabled.is_(True)) ) - db_cfg = result.scalars().first() - if not db_cfg: - raise ValueError(f"Image generation config {config_id} not found") + db_model = result.scalars().first() + if not db_model or not db_model.connection or not db_model.connection.enabled: + raise ValueError(f"Image generation model {config_id} not found") + if not (db_model.capabilities or {}).get("image_gen"): + raise ValueError(f"Model {config_id} is not image-generation capable") - provider_prefix = _resolve_provider_prefix( - db_cfg.provider.value, db_cfg.custom_provider + model_string, resolved_kwargs = to_litellm( + db_model.connection, + db_model.model_id, ) - model_string = f"{provider_prefix}/{db_cfg.model_name}" - gen_kwargs["api_key"] = db_cfg.api_key - api_base = resolve_api_base( - provider=db_cfg.provider.value, - provider_prefix=provider_prefix, - config_api_base=db_cfg.api_base, - ) - if api_base: - gen_kwargs["api_base"] = api_base - if db_cfg.api_version: - gen_kwargs["api_version"] = db_cfg.api_version - if db_cfg.litellm_params: - gen_kwargs.update(db_cfg.litellm_params) + gen_kwargs.update(resolved_kwargs) # User model override if image_gen.model: From 18606fe3880cd4c53ba912620bdd91ad6e0bccd8 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:48:53 +0530 Subject: [PATCH 005/212] feat(automations): add model connection policy support --- .../builtin/agent_task/dependencies.py | 30 +++---- .../actions/builtin/agent_task/invoke.py | 8 +- .../app/automations/actions/types.py | 6 +- .../app/automations/runtime/executor.py | 8 +- .../schemas/definition/envelope.py | 10 +-- .../app/automations/services/automation.py | 12 +-- .../app/automations/services/model_policy.py | 87 +++++++------------ 7 files changed, 67 insertions(+), 94 deletions(-) diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py b/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py index 4ef8c52bf..c9584ae2a 100644 --- a/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py @@ -39,31 +39,31 @@ async def build_dependencies( *, session: AsyncSession, search_space_id: int, - agent_llm_id: int | None = None, - image_generation_config_id: int | None = None, - vision_llm_config_id: int | None = None, + chat_model_id: int | None = None, + image_gen_model_id: int | None = None, + vision_model_id: int | None = None, ) -> AgentDependencies: """Load the LLM bundle, connector service, and a per-invoke in-memory checkpointer. - Resolves the agent LLM from the automation's *captured* model snapshot - (``agent_llm_id``) so runs are insulated from later chat/search-space model + Resolves the chat model from the automation's *captured* model snapshot + (``chat_model_id``) so runs are insulated from later chat/search-space model changes. The model policy is enforced here as a runtime backstop: a captured model that is no longer billable (e.g. a premium global config was removed) fails the run clearly instead of silently consuming a free model. - When ``agent_llm_id`` is ``None`` (no captured snapshot — defensive fallback), - fall back to the live search space's ``agent_llm_id`` and validate that. + When ``chat_model_id`` is ``None`` (no captured snapshot — defensive fallback), + fall back to the live search space's ``chat_model_id`` and validate that. """ - if agent_llm_id is not None: + if chat_model_id is not None: try: assert_models_billable( - agent_llm_id=agent_llm_id, - image_generation_config_id=image_generation_config_id, - vision_llm_config_id=vision_llm_config_id, + chat_model_id=chat_model_id, + image_gen_model_id=image_gen_model_id, + vision_model_id=vision_model_id, ) except AutomationModelPolicyError as exc: raise DependencyError(str(exc)) from exc - resolved_agent_llm_id = agent_llm_id or 0 + resolved_chat_model_id = chat_model_id or 0 else: search_space = await session.get(SearchSpace, search_space_id) if search_space is None: @@ -72,15 +72,15 @@ async def build_dependencies( assert_automation_models_billable(search_space) except AutomationModelPolicyError as exc: raise DependencyError(str(exc)) from exc - resolved_agent_llm_id = search_space.agent_llm_id or 0 + resolved_chat_model_id = search_space.chat_model_id or 0 llm, agent_config, err = await load_llm_bundle( session, - config_id=resolved_agent_llm_id, + config_id=resolved_chat_model_id, search_space_id=search_space_id, ) if err is not None or llm is None: - raise DependencyError(err or "failed to load agent LLM config") + raise DependencyError(err or "failed to load chat model config") connector_service, firecrawl_api_key = await setup_connector_and_firecrawl( session, search_space_id=search_space_id diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py b/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py index aa96e4f6e..c3a35930d 100644 --- a/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py @@ -150,9 +150,9 @@ async def run_agent_task( deps = await build_dependencies( session=agent_session, search_space_id=ctx.search_space_id, - agent_llm_id=ctx.agent_llm_id, - image_generation_config_id=ctx.image_generation_config_id, - vision_llm_config_id=ctx.vision_llm_config_id, + chat_model_id=ctx.chat_model_id, + image_gen_model_id=ctx.image_gen_model_id, + vision_model_id=ctx.vision_model_id, ) agent = await create_multi_agent_chat_deep_agent( @@ -167,7 +167,7 @@ async def run_agent_task( firecrawl_api_key=deps.firecrawl_api_key, thread_visibility=ChatVisibility.PRIVATE, mentioned_document_ids=mentioned_document_ids, - image_generation_config_id=ctx.image_generation_config_id, + image_gen_model_id=ctx.image_gen_model_id, ) agent_query, runtime_context = await _resolve_mention_context( diff --git a/surfsense_backend/app/automations/actions/types.py b/surfsense_backend/app/automations/actions/types.py index 453721a43..3ee427512 100644 --- a/surfsense_backend/app/automations/actions/types.py +++ b/surfsense_backend/app/automations/actions/types.py @@ -23,9 +23,9 @@ class ActionContext: # Captured model snapshot from the automation definition (``definition.models``), # resolved per run instead of the live search space. ``None`` falls back to the # search space's current prefs (defensive; should not happen post-capture). - agent_llm_id: int | None = None - image_generation_config_id: int | None = None - vision_llm_config_id: int | None = None + chat_model_id: int | None = None + image_gen_model_id: int | None = None + vision_model_id: int | None = None ActionHandler = Callable[[dict[str, Any]], Awaitable[Any]] diff --git a/surfsense_backend/app/automations/runtime/executor.py b/surfsense_backend/app/automations/runtime/executor.py index da249d8e5..bcdab3940 100644 --- a/surfsense_backend/app/automations/runtime/executor.py +++ b/surfsense_backend/app/automations/runtime/executor.py @@ -132,9 +132,7 @@ def _build_action_ctx( step_id=step.step_id, search_space_id=automation.search_space_id, creator_user_id=automation.created_by_user_id, - agent_llm_id=models.agent_llm_id if models else None, - image_generation_config_id=( - models.image_generation_config_id if models else None - ), - vision_llm_config_id=models.vision_llm_config_id if models else None, + chat_model_id=models.chat_model_id if models else None, + image_gen_model_id=models.image_gen_model_id if models else None, + vision_model_id=models.vision_model_id if models else None, ) diff --git a/surfsense_backend/app/automations/schemas/definition/envelope.py b/surfsense_backend/app/automations/schemas/definition/envelope.py index 7ca55b1ce..787534d4a 100644 --- a/surfsense_backend/app/automations/schemas/definition/envelope.py +++ b/surfsense_backend/app/automations/schemas/definition/envelope.py @@ -14,16 +14,16 @@ from .trigger_spec import TriggerSpec class AutomationModels(BaseModel): """Captured model profile for an automation. - Snapshotted from the search space's preferences at create time so runs are - insulated from later chat/search-space model changes. Config-id conventions + Snapshotted from the search space's model roles at create time so runs are + insulated from later chat/search-space model changes. Model-id conventions match the shared scheme (``0`` Auto, ``< 0`` global, ``> 0`` BYOK). """ model_config = ConfigDict(extra="forbid") - agent_llm_id: int = 0 - image_generation_config_id: int = 0 - vision_llm_config_id: int = 0 + chat_model_id: int = 0 + image_gen_model_id: int = 0 + vision_model_id: int = 0 class AutomationDefinition(BaseModel): diff --git a/surfsense_backend/app/automations/services/automation.py b/surfsense_backend/app/automations/services/automation.py index 4227161e2..1d371c35d 100644 --- a/surfsense_backend/app/automations/services/automation.py +++ b/surfsense_backend/app/automations/services/automation.py @@ -57,9 +57,9 @@ class AutomationService: else: search_space = await self._assert_models_billable(payload.search_space_id) payload.definition.models = AutomationModels( - agent_llm_id=search_space.agent_llm_id or 0, - image_generation_config_id=search_space.image_generation_config_id or 0, - vision_llm_config_id=search_space.vision_llm_config_id or 0, + chat_model_id=search_space.chat_model_id or 0, + image_gen_model_id=search_space.image_gen_model_id or 0, + vision_model_id=search_space.vision_model_id or 0, ) automation = Automation( @@ -225,9 +225,9 @@ class AutomationService: """ try: assert_models_billable( - agent_llm_id=models.agent_llm_id, - image_generation_config_id=models.image_generation_config_id, - vision_llm_config_id=models.vision_llm_config_id, + chat_model_id=models.chat_model_id, + image_gen_model_id=models.image_gen_model_id, + vision_model_id=models.vision_model_id, ) except AutomationModelPolicyError as exc: raise HTTPException(status_code=422, detail=str(exc)) from exc diff --git a/surfsense_backend/app/automations/services/model_policy.py b/surfsense_backend/app/automations/services/model_policy.py index 7e3e46b61..b160fc78d 100644 --- a/surfsense_backend/app/automations/services/model_policy.py +++ b/surfsense_backend/app/automations/services/model_policy.py @@ -24,70 +24,45 @@ from typing import TYPE_CHECKING, Literal if TYPE_CHECKING: from app.db import SearchSpace -ModelKind = Literal["llm", "image", "vision"] +ModelKind = Literal["chat", "image", "vision"] _KIND_LABEL: dict[ModelKind, str] = { - "llm": "agent LLM", + "chat": "chat model", "image": "image generation model", "vision": "vision model", } -def _is_premium_global(kind: ModelKind, config_id: int) -> bool: - """Return True if a negative (global) config id is a premium tier model.""" +def _is_premium_global(model_id: int) -> bool: + """Return True if a negative (global) model id is a premium tier model.""" from app.config import config as app_config - cfg: dict | None = None - if kind == "llm": - from app.agents.chat.runtime.llm_config import ( - load_global_llm_config_by_id, - ) - - cfg = load_global_llm_config_by_id(config_id) - elif kind == "image": - cfg = next( - ( - c - for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS - if c.get("id") == config_id - ), - None, - ) - else: # vision - cfg = next( - ( - c - for c in app_config.GLOBAL_VISION_LLM_CONFIGS - if c.get("id") == config_id - ), - None, - ) - - if not cfg: + model = next((m for m in app_config.GLOBAL_MODELS if m.get("id") == model_id), None) + if not model: return False - return str(cfg.get("billing_tier", "free")).lower() == "premium" + return str(model.get("billing_tier", "free")).lower() == "premium" -def _classify(kind: ModelKind, config_id: int | None) -> tuple[bool, str]: - """Classify a resolved config id as allowed or blocked. +def _classify(kind: ModelKind, model_id: int | None) -> tuple[bool, str]: + """Classify a resolved model id as allowed or blocked. Returns ``(allowed, reason)``; ``reason`` is empty when allowed. """ label = _KIND_LABEL[kind] - if config_id is None or config_id == 0: + if model_id is None or model_id == 0: return ( False, f"The {label} is set to Auto mode. Automations require an explicit " "premium model or your own (BYOK) model so every run is billable.", ) - if config_id > 0: - # Positive id → user-owned BYOK config. Always allowed. + if model_id > 0: + # Positive id -> user/search-space BYOK model. Always allowed. return True, "" - # Negative id → global config. Allowed only if premium. - if _is_premium_global(kind, config_id): + # Negative id -> global model. Allowed only if premium. + if _is_premium_global(model_id): return True, "" return ( @@ -99,27 +74,27 @@ def _classify(kind: ModelKind, config_id: int | None) -> tuple[bool, str]: def get_model_eligibility( *, - agent_llm_id: int | None, - image_generation_config_id: int | None, - vision_llm_config_id: int | None, + chat_model_id: int | None, + image_gen_model_id: int | None, + vision_model_id: int | None, ) -> dict: - """Return ``{"allowed": bool, "violations": [...]}`` for explicit config ids. + """Return ``{"allowed": bool, "violations": [...]}`` for explicit model ids. The ID-based core shared by both the search-space path (creation/eligibility) and the captured-snapshot path (runtime backstop). Each violation is ``{"kind", "config_id", "reason"}``. """ checks: list[tuple[ModelKind, int | None]] = [ - ("llm", agent_llm_id), - ("image", image_generation_config_id), - ("vision", vision_llm_config_id), + ("chat", chat_model_id), + ("image", image_gen_model_id), + ("vision", vision_model_id), ] violations: list[dict] = [] for kind, config_id in checks: allowed, reason = _classify(kind, config_id) if not allowed: - violations.append({"kind": kind, "config_id": config_id, "reason": reason}) + violations.append({"kind": kind, "model_id": config_id, "reason": reason}) return {"allowed": not violations, "violations": violations} @@ -131,9 +106,9 @@ def get_automation_model_eligibility(search_space: SearchSpace) -> dict: wrapper over :func:`get_model_eligibility`. """ return get_model_eligibility( - agent_llm_id=search_space.agent_llm_id, - image_generation_config_id=search_space.image_generation_config_id, - vision_llm_config_id=search_space.vision_llm_config_id, + chat_model_id=search_space.chat_model_id, + image_gen_model_id=search_space.image_gen_model_id, + vision_model_id=search_space.vision_model_id, ) @@ -150,9 +125,9 @@ class AutomationModelPolicyError(Exception): def assert_models_billable( *, - agent_llm_id: int | None, - image_generation_config_id: int | None, - vision_llm_config_id: int | None, + chat_model_id: int | None, + image_gen_model_id: int | None, + vision_model_id: int | None, ) -> None: """Raise :class:`AutomationModelPolicyError` if any explicit id is not billable. @@ -160,9 +135,9 @@ def assert_models_billable( captured model snapshot. """ result = get_model_eligibility( - agent_llm_id=agent_llm_id, - image_generation_config_id=image_generation_config_id, - vision_llm_config_id=vision_llm_config_id, + chat_model_id=chat_model_id, + image_gen_model_id=image_gen_model_id, + vision_model_id=vision_model_id, ) if not result["allowed"]: raise AutomationModelPolicyError(result["violations"]) From 32ab2b8713e6b67493ed6873e3e74f11e6013240 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:49:07 +0530 Subject: [PATCH 006/212] feat(web): expose model policies in automations --- .../builder/automation-builder-form.tsx | 2 +- .../builder/automation-model-fields.tsx | 14 +-- .../automations/automations-mutation.atoms.ts | 6 +- .../tool-ui/automation/create-automation.tsx | 18 +-- .../contracts/types/automation.types.ts | 6 +- .../hooks/use-automation-eligible-models.ts | 117 +++++++----------- .../lib/automations/builder-schema.ts | 20 +-- 7 files changed, 77 insertions(+), 106 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx index 59967080f..a68e53a1c 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx @@ -130,7 +130,7 @@ export function AutomationBuilderForm({ // data into state, so there's no flicker/loop and the user's pick is sticky. const resolvedModels = useMemo( () => ({ - agentLlmId: form.models.agentLlmId || eligibleModels.llm.defaultId || 0, + chatModelId: form.models.chatModelId || eligibleModels.llm.defaultId || 0, imageConfigId: form.models.imageConfigId || eligibleModels.image.defaultId || 0, visionConfigId: form.models.visionConfigId || eligibleModels.vision.defaultId || 0, }), diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-model-fields.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-model-fields.tsx index 2c4a0bf60..6dd42366b 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-model-fields.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-model-fields.tsx @@ -25,7 +25,7 @@ import { getProviderIcon } from "@/lib/provider-icons"; import { Field } from "./form-field"; export interface AutomationModelSelection { - agentLlmId: number; + chatModelId: number; imageConfigId: number; visionConfigId: number; } @@ -39,7 +39,7 @@ interface AutomationModelFieldsProps { } /** - * Three eligible-only model pickers (Agent LLM / Image / Vision) for the + * Three eligible-only model pickers (Chat / Image / Vision) for the * automation builder + chat approval card. Options come from * {@link useAutomationEligibleModels} (premium globals + BYOK only); selection * is validated + snapshotted onto `definition.models` at create time. @@ -51,18 +51,18 @@ export function AutomationModelFields({ errors, }: AutomationModelFieldsProps) { const { llm, image, vision, isLoading } = useAutomationEligibleModels(); - const rolesHref = `/dashboard/${searchSpaceId}/search-space-settings/roles`; + const rolesHref = `/dashboard/${searchSpaceId}/search-space-settings/models`; return (
onChange({ agentLlmId: id })} + error={errors?.chatModelId} + onChange={(id) => onChange({ chatModelId: id })} /> ({ task_count: variables.definition.plan.length, trigger_type: variables.triggers?.[0]?.type ?? "none", has_schedule: (variables.triggers?.length ?? 0) > 0, - agent_llm_id: variables.definition.models?.agent_llm_id, - image_generation_config_id: variables.definition.models?.image_generation_config_id, - vision_llm_config_id: variables.definition.models?.vision_llm_config_id, + chat_model_id: variables.definition.models?.chat_model_id, + image_gen_model_id: variables.definition.models?.image_gen_model_id, + vision_model_id: variables.definition.models?.vision_model_id, tags_count: variables.definition.metadata?.tags?.length, }); }, diff --git a/surfsense_web/components/tool-ui/automation/create-automation.tsx b/surfsense_web/components/tool-ui/automation/create-automation.tsx index 24e9d66bd..2a7d09f53 100644 --- a/surfsense_web/components/tool-ui/automation/create-automation.tsx +++ b/surfsense_web/components/tool-ui/automation/create-automation.tsx @@ -113,7 +113,7 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) { const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom); const eligibleModels = useAutomationEligibleModels(); const [modelSelection, setModelSelection] = useState({ - agentLlmId: 0, + chatModelId: 0, imageConfigId: 0, visionConfigId: 0, }); @@ -121,7 +121,7 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) { // default. No effect seeds async hook data into state. const resolvedModels = useMemo( () => ({ - agentLlmId: modelSelection.agentLlmId || eligibleModels.llm.defaultId || 0, + chatModelId: modelSelection.chatModelId || eligibleModels.llm.defaultId || 0, imageConfigId: modelSelection.imageConfigId || eligibleModels.image.defaultId || 0, visionConfigId: modelSelection.visionConfigId || eligibleModels.vision.defaultId || 0, }), @@ -133,7 +133,7 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) { ] ); const modelsResolved = - resolvedModels.agentLlmId !== 0 && + resolvedModels.chatModelId !== 0 && resolvedModels.imageConfigId !== 0 && resolvedModels.visionConfigId !== 0; @@ -147,9 +147,9 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) { definition: { ...baseDefinition, models: { - agent_llm_id: resolvedModels.agentLlmId, - image_generation_config_id: resolvedModels.imageConfigId, - vision_llm_config_id: resolvedModels.visionConfigId, + chat_model_id: resolvedModels.chatModelId, + image_gen_model_id: resolvedModels.imageConfigId, + vision_model_id: resolvedModels.visionConfigId, }, }, }; @@ -162,9 +162,9 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) { trigger_type: (triggers[0] as { type?: string } | undefined)?.type ?? (triggers.length ? undefined : "none"), - agent_llm_id: resolvedModels.agentLlmId, - image_generation_config_id: resolvedModels.imageConfigId, - vision_llm_config_id: resolvedModels.visionConfigId, + chat_model_id: resolvedModels.chatModelId, + image_gen_model_id: resolvedModels.imageConfigId, + vision_model_id: resolvedModels.visionConfigId, }); onDecision({ type: "edit", diff --git a/surfsense_web/contracts/types/automation.types.ts b/surfsense_web/contracts/types/automation.types.ts index 45670d245..6331a663c 100644 --- a/surfsense_web/contracts/types/automation.types.ts +++ b/surfsense_web/contracts/types/automation.types.ts @@ -63,9 +63,9 @@ export type Inputs = z.infer; // Captured model snapshot (server-managed). Set at create time and preserved // across edits so runs are insulated from later chat/search-space model changes. export const automationModels = z.object({ - agent_llm_id: z.number().int().default(0), - image_generation_config_id: z.number().int().default(0), - vision_llm_config_id: z.number().int().default(0), + chat_model_id: z.number().int().default(0), + image_gen_model_id: z.number().int().default(0), + vision_model_id: z.number().int().default(0), }); export type AutomationModels = z.infer; diff --git a/surfsense_web/hooks/use-automation-eligible-models.ts b/surfsense_web/hooks/use-automation-eligible-models.ts index e74994221..e75235c56 100644 --- a/surfsense_web/hooks/use-automation-eligible-models.ts +++ b/surfsense_web/hooks/use-automation-eligible-models.ts @@ -3,18 +3,11 @@ import { useAtomValue } from "jotai"; import { useMemo } from "react"; import { - globalImageGenConfigsAtom, - imageGenConfigsAtom, -} from "@/atoms/image-gen-config/image-gen-config-query.atoms"; -import { - globalNewLLMConfigsAtom, - llmPreferencesAtom, - newLLMConfigsAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; -import { - globalVisionLLMConfigsAtom, - visionLLMConfigsAtom, -} from "@/atoms/vision-llm-config/vision-llm-config-query.atoms"; + globalModelConnectionsAtom, + modelConnectionsAtom, + modelRolesAtom, +} from "@/atoms/model-connections/model-connections-query.atoms"; +import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; /** * A single model the user may pick for an automation slot. @@ -44,48 +37,40 @@ export interface AutomationEligibleModels { isLoading: boolean; } -interface GlobalConfigLike { - id: number; - name: string; - model_name: string; - provider: string; - is_premium?: boolean; - is_auto_mode?: boolean; -} - -interface UserConfigLike { - id: number; - name: string; - model_name: string; - provider: string; -} - /** * Build the eligible option list for one model kind: premium globals - * (`is_premium === true`, never Auto mode) followed by all BYOK configs. + * followed by all BYOK/search-space models. */ function buildKind( - globals: GlobalConfigLike[] | undefined, - byok: UserConfigLike[] | undefined, + globals: ConnectionRead[] | undefined, + byok: ConnectionRead[] | undefined, + capability: "chat" | "image_gen" | "vision", prefId: number | null | undefined ): EligibleModelKind { - const premiumGlobals: EligibleModelOption[] = (globals ?? []) - .filter((c) => c.is_premium === true && !c.is_auto_mode) - .map((c) => ({ - id: c.id, - name: c.name, - modelName: c.model_name, - provider: c.provider, - isBYOK: false, - })); + const toOption = (connection: ConnectionRead, model: ModelRead, isBYOK: boolean) => ({ + id: model.id, + name: model.display_name || model.model_id, + modelName: model.model_id, + provider: connection.native_provider || connection.protocol, + isBYOK, + }); - const byokOptions: EligibleModelOption[] = (byok ?? []).map((c) => ({ - id: c.id, - name: c.name, - modelName: c.model_name, - provider: c.provider, - isBYOK: true, - })); + const premiumGlobals: EligibleModelOption[] = (globals ?? []).flatMap((connection) => + connection.models + .filter( + (model) => + model.enabled && + Boolean(model.capabilities?.[capability]) && + String(model.billing_tier ?? "").toLowerCase() === "premium" + ) + .map((model) => toOption(connection, model, false)) + ); + + const byokOptions: EligibleModelOption[] = (byok ?? []).flatMap((connection) => + connection.models + .filter((model) => model.enabled && Boolean(model.capabilities?.[capability])) + .map((model) => toOption(connection, model, true)) + ); const options = [...premiumGlobals, ...byokOptions]; const byId = new Map(options.map((o) => [o.id, o])); @@ -105,46 +90,32 @@ function buildKind( * (premium globals + user BYOK — never free globals or Auto mode), with a * default selection seeded from the search space's role preferences. * - * Everything is derived during render from the existing config query atoms; + * Everything is derived during render from the connection/model query atoms; * there are no effects, so option lists/maps keep stable references. */ export function useAutomationEligibleModels(): AutomationEligibleModels { - const { data: llmUserConfigs, isLoading: llmUserLoading } = useAtomValue(newLLMConfigsAtom); - const { data: llmGlobalConfigs, isLoading: llmGlobalLoading } = - useAtomValue(globalNewLLMConfigsAtom); - const { data: preferences, isLoading: prefsLoading } = useAtomValue(llmPreferencesAtom); - const { data: imageGlobalConfigs, isLoading: imageGlobalLoading } = - useAtomValue(globalImageGenConfigsAtom); - const { data: imageUserConfigs, isLoading: imageUserLoading } = useAtomValue(imageGenConfigsAtom); - const { data: visionGlobalConfigs, isLoading: visionGlobalLoading } = useAtomValue( - globalVisionLLMConfigsAtom + const { data: byokConnections, isLoading: byokLoading } = useAtomValue(modelConnectionsAtom); + const { data: globalConnections, isLoading: globalLoading } = useAtomValue( + globalModelConnectionsAtom ); - const { data: visionUserConfigs, isLoading: visionUserLoading } = - useAtomValue(visionLLMConfigsAtom); + const { data: roles, isLoading: rolesLoading } = useAtomValue(modelRolesAtom); const llm = useMemo( - () => buildKind(llmGlobalConfigs, llmUserConfigs, preferences?.agent_llm_id), - [llmGlobalConfigs, llmUserConfigs, preferences?.agent_llm_id] + () => buildKind(globalConnections, byokConnections, "chat", roles?.chat_model_id), + [globalConnections, byokConnections, roles?.chat_model_id] ); const image = useMemo( - () => buildKind(imageGlobalConfigs, imageUserConfigs, preferences?.image_generation_config_id), - [imageGlobalConfigs, imageUserConfigs, preferences?.image_generation_config_id] + () => buildKind(globalConnections, byokConnections, "image_gen", roles?.image_gen_model_id), + [globalConnections, byokConnections, roles?.image_gen_model_id] ); const vision = useMemo( - () => buildKind(visionGlobalConfigs, visionUserConfigs, preferences?.vision_llm_config_id), - [visionGlobalConfigs, visionUserConfigs, preferences?.vision_llm_config_id] + () => buildKind(globalConnections, byokConnections, "vision", roles?.vision_model_id), + [globalConnections, byokConnections, roles?.vision_model_id] ); - const isLoading = - llmUserLoading || - llmGlobalLoading || - prefsLoading || - imageGlobalLoading || - imageUserLoading || - visionGlobalLoading || - visionUserLoading; + const isLoading = byokLoading || globalLoading || rolesLoading; return useMemo(() => ({ llm, image, vision, isLoading }), [llm, image, vision, isLoading]); } diff --git a/surfsense_web/lib/automations/builder-schema.ts b/surfsense_web/lib/automations/builder-schema.ts index c2bd69209..5bb034bef 100644 --- a/surfsense_web/lib/automations/builder-schema.ts +++ b/surfsense_web/lib/automations/builder-schema.ts @@ -73,7 +73,7 @@ export type BuilderExecution = z.infer; * later chat/search-space model changes. */ export const builderModelsSchema = z.object({ - agentLlmId: z.number().int(), + chatModelId: z.number().int(), imageConfigId: z.number().int(), visionConfigId: z.number().int(), }); @@ -90,7 +90,7 @@ export const builderFormSchema = z.object({ tags: z.array(z.string()), /** Carried through from an edited definition so we don't drop it. */ goal: z.string().nullable(), - /** Selected agent/image/vision models (``0`` = use the eligible default). */ + /** Selected chat/image/vision models (``0`` = use the eligible default). */ models: builderModelsSchema, }); export type BuilderForm = z.infer; @@ -147,7 +147,7 @@ export function createEmptyForm(): BuilderForm { }, tags: [], goal: null, - models: { agentLlmId: 0, imageConfigId: 0, visionConfigId: 0 }, + models: { chatModelId: 0, imageConfigId: 0, visionConfigId: 0 }, }; } @@ -240,9 +240,9 @@ function buildDefinition(form: BuilderForm): AutomationDefinition { ...(hasResolvedModels(form.models) ? { models: { - agent_llm_id: form.models.agentLlmId, - image_generation_config_id: form.models.imageConfigId, - vision_llm_config_id: form.models.visionConfigId, + chat_model_id: form.models.chatModelId, + image_gen_model_id: form.models.imageConfigId, + vision_model_id: form.models.visionConfigId, }, } : {}), @@ -251,7 +251,7 @@ function buildDefinition(form: BuilderForm): AutomationDefinition { /** True once every model slot holds a concrete (non-zero) id. */ export function hasResolvedModels(models: BuilderModels): boolean { - return models.agentLlmId !== 0 && models.imageConfigId !== 0 && models.visionConfigId !== 0; + return models.chatModelId !== 0 && models.imageConfigId !== 0 && models.visionConfigId !== 0; } /** The desired schedule trigger for this form, or ``null`` if none. */ @@ -500,9 +500,9 @@ function modelsFromDefinition(raw: unknown): BuilderModels { const m = asRecord(raw); const num = (value: unknown) => (typeof value === "number" ? value : 0); return { - agentLlmId: num(m.agent_llm_id), - imageConfigId: num(m.image_generation_config_id), - visionConfigId: num(m.vision_llm_config_id), + chatModelId: num(m.chat_model_id), + imageConfigId: num(m.image_gen_model_id), + visionConfigId: num(m.vision_model_id), }; } From 0674accc23c3474758b39203752ce8bb5c5d1f1d Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:49:21 +0530 Subject: [PATCH 007/212] feat(web): add model connection client data layer --- .../model-connections-mutation.atoms.ts | 129 ++++++++++++++++++ .../model-connections-query.atoms.ts | 32 +++++ .../types/model-connections.types.ts | 98 +++++++++++++ .../lib/apis/model-connections-api.service.ts | 88 ++++++++++++ surfsense_web/lib/query-client/cache-keys.ts | 5 + 5 files changed, 352 insertions(+) create mode 100644 surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts create mode 100644 surfsense_web/atoms/model-connections/model-connections-query.atoms.ts create mode 100644 surfsense_web/contracts/types/model-connections.types.ts create mode 100644 surfsense_web/lib/apis/model-connections-api.service.ts diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts new file mode 100644 index 000000000..7d58a402c --- /dev/null +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -0,0 +1,129 @@ +import { atomWithMutation } from "jotai-tanstack-query"; +import { toast } from "sonner"; +import type { + ConnectionCreateRequest, + ConnectionUpdateRequest, + ModelRoles, + ModelUpdateRequest, +} from "@/contracts/types/model-connections.types"; +import { modelConnectionsApiService } from "@/lib/apis/model-connections-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { queryClient } from "@/lib/query-client/client"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +function invalidateModelConnections(searchSpaceId: number) { + queryClient.invalidateQueries({ + queryKey: cacheKeys.modelConnections.all(searchSpaceId), + }); + queryClient.invalidateQueries({ + queryKey: cacheKeys.modelConnections.roles(searchSpaceId), + }); +} + +export const createModelConnectionMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["model-connections", "create"], + mutationFn: (request: ConnectionCreateRequest) => + modelConnectionsApiService.createConnection(request), + onSuccess: () => { + toast.success("Connection created"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to create connection"), + }; +}); + +export const updateModelConnectionMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["model-connections", "update"], + mutationFn: ({ id, data }: { id: number; data: ConnectionUpdateRequest }) => + modelConnectionsApiService.updateConnection(id, data), + onSuccess: () => { + toast.success("Connection updated"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to update connection"), + }; +}); + +export const deleteModelConnectionMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["model-connections", "delete"], + mutationFn: (id: number) => modelConnectionsApiService.deleteConnection(id), + onSuccess: () => { + toast.success("Connection deleted"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to delete connection"), + }; +}); + +export const verifyModelConnectionMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["model-connections", "verify"], + mutationFn: (id: number) => modelConnectionsApiService.verifyConnection(id), + onSuccess: (result) => { + if (result.ok) toast.success("Connection verified"); + else toast.error(result.message || "Connection failed"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to verify connection"), + }; +}); + +export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["model-connections", "discover"], + mutationFn: (id: number) => modelConnectionsApiService.discoverModels(id), + onSuccess: () => { + toast.success("Models discovered"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to discover models"), + }; +}); + +export const updateModelMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["models", "update"], + mutationFn: ({ id, data }: { id: number; data: ModelUpdateRequest }) => + modelConnectionsApiService.updateModel(id, data), + onSuccess: () => invalidateModelConnections(searchSpaceId), + onError: (error: Error) => toast.error(error.message || "Failed to update model"), + }; +}); + +export const testModelMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["models", "test"], + mutationFn: (id: number) => modelConnectionsApiService.testModel(id), + onSuccess: (result) => { + if (result.ok) toast.success("Model test succeeded"); + else toast.error(result.message || "Model test failed"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to test model"), + }; +}); + +export const updateModelRolesMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["model-roles", "update"], + mutationFn: (roles: ModelRoles) => + modelConnectionsApiService.updateModelRoles(searchSpaceId, roles), + onSuccess: () => { + queryClient.invalidateQueries({ + queryKey: cacheKeys.modelConnections.roles(searchSpaceId), + }); + }, + onError: (error: Error) => toast.error(error.message || "Failed to update model roles"), + }; +}); diff --git a/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts new file mode 100644 index 000000000..617ffe124 --- /dev/null +++ b/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts @@ -0,0 +1,32 @@ +import { atomWithQuery } from "jotai-tanstack-query"; +import { modelConnectionsApiService } from "@/lib/apis/model-connections-api.service"; +import { getBearerToken } from "@/lib/auth-utils"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +export const globalModelConnectionsAtom = atomWithQuery(() => ({ + queryKey: cacheKeys.modelConnections.global(), + enabled: !!getBearerToken(), + staleTime: 10 * 60 * 1000, + queryFn: () => modelConnectionsApiService.getGlobalConnections(), +})); + +export const modelConnectionsAtom = atomWithQuery((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + queryKey: cacheKeys.modelConnections.all(searchSpaceId), + enabled: !!searchSpaceId, + staleTime: 5 * 60 * 1000, + queryFn: () => modelConnectionsApiService.getConnections(searchSpaceId), + }; +}); + +export const modelRolesAtom = atomWithQuery((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + queryKey: cacheKeys.modelConnections.roles(searchSpaceId), + enabled: !!searchSpaceId, + staleTime: 5 * 60 * 1000, + queryFn: () => modelConnectionsApiService.getModelRoles(searchSpaceId), + }; +}); diff --git a/surfsense_web/contracts/types/model-connections.types.ts b/surfsense_web/contracts/types/model-connections.types.ts new file mode 100644 index 000000000..14f93c61a --- /dev/null +++ b/surfsense_web/contracts/types/model-connections.types.ts @@ -0,0 +1,98 @@ +import { z } from "zod"; + +export const connectionProtocolEnum = z.enum(["OLLAMA", "OPENAI_COMPATIBLE", "NATIVE"]); +export const connectionScopeEnum = z.enum(["GLOBAL", "SEARCH_SPACE", "USER"]); +export const modelSourceEnum = z.enum(["DISCOVERED", "MANUAL"]); + +export const modelCapabilities = z.object({ + chat: z.boolean().optional(), + vision: z.boolean().optional(), + image_gen: z.boolean().optional(), + embedding: z.boolean().optional(), + tools: z.boolean().optional(), +}); + +export const modelRead = z.object({ + id: z.number(), + connection_id: z.number(), + model_id: z.string(), + display_name: z.string().nullable().optional(), + source: z.union([modelSourceEnum, z.string()]), + capabilities: z.record(z.string(), z.any()).default({}), + capabilities_declared: z.record(z.string(), z.any()).default({}), + capabilities_verified: z.record(z.string(), z.any()).default({}), + capabilities_override: z.record(z.string(), z.any()).default({}), + embedding_dimension: z.number().nullable().optional(), + enabled: z.boolean(), + billing_tier: z.string().nullable().optional(), + catalog: z.record(z.string(), z.any()).default({}), + created_at: z.string().nullable().optional(), +}); + +export const connectionRead = z.object({ + id: z.number(), + protocol: z.union([connectionProtocolEnum, z.string()]), + native_provider: z.string().nullable().optional(), + base_url: z.string().nullable().optional(), + extra: z.record(z.string(), z.any()).default({}), + scope: z.union([connectionScopeEnum, z.string()]), + search_space_id: z.number().nullable().optional(), + user_id: z.string().nullable().optional(), + enabled: z.boolean(), + has_api_key: z.boolean(), + last_verified_at: z.string().nullable().optional(), + last_status: z.string().nullable().optional(), + last_error: z.string().nullable().optional(), + models: z.array(modelRead).default([]), + created_at: z.string().nullable().optional(), +}); + +export const connectionCreateRequest = z.object({ + protocol: connectionProtocolEnum, + native_provider: z.string().nullable().optional(), + base_url: z.string().nullable().optional(), + api_key: z.string().nullable().optional(), + extra: z.record(z.string(), z.any()).default({}), + scope: connectionScopeEnum.default("SEARCH_SPACE"), + search_space_id: z.number().nullable().optional(), + enabled: z.boolean().default(true), +}); + +export const connectionUpdateRequest = z.object({ + native_provider: z.string().nullable().optional(), + base_url: z.string().nullable().optional(), + api_key: z.string().nullable().optional(), + extra: z.record(z.string(), z.any()).optional(), + enabled: z.boolean().optional(), +}); + +export const modelUpdateRequest = z.object({ + display_name: z.string().nullable().optional(), + enabled: z.boolean().optional(), + capabilities_override: z.record(z.string(), z.any()).optional(), +}); + +export const verifyConnectionResponse = z.object({ + status: z.string(), + ok: z.boolean(), + message: z.string().default(""), +}); + +export const modelRoles = z.object({ + chat_model_id: z.number().nullable().optional(), + vision_model_id: z.number().nullable().optional(), + image_gen_model_id: z.number().nullable().optional(), +}); + +export const connectionListResponse = z.array(connectionRead); +export const modelListResponse = z.array(modelRead); + +export type ConnectionProtocol = z.infer; +export type ConnectionScope = z.infer; +export type ModelRead = z.infer; +export type ConnectionRead = z.infer; +export type ConnectionCreateRequest = z.infer; +export type ConnectionUpdateRequest = z.infer; +export type ModelUpdateRequest = z.infer; +export type ModelRoles = z.infer; +export type VerifyConnectionResponse = z.infer; diff --git a/surfsense_web/lib/apis/model-connections-api.service.ts b/surfsense_web/lib/apis/model-connections-api.service.ts new file mode 100644 index 000000000..ca92ad11b --- /dev/null +++ b/surfsense_web/lib/apis/model-connections-api.service.ts @@ -0,0 +1,88 @@ +import { + type ConnectionCreateRequest, + type ConnectionUpdateRequest, + connectionCreateRequest, + connectionListResponse, + connectionRead, + connectionUpdateRequest, + type ModelRoles, + type ModelUpdateRequest, + modelListResponse, + modelRead, + modelRoles, + modelUpdateRequest, + verifyConnectionResponse, +} from "@/contracts/types/model-connections.types"; +import { ValidationError } from "../error"; +import { baseApiService } from "./base-api.service"; + +class ModelConnectionsApiService { + getGlobalConnections = async () => { + return baseApiService.get(`/api/v1/global-model-connections`, connectionListResponse); + }; + + getConnections = async (searchSpaceId: number) => { + return baseApiService.get( + `/api/v1/model-connections?search_space_id=${searchSpaceId}`, + connectionListResponse + ); + }; + + createConnection = async (request: ConnectionCreateRequest) => { + const parsed = connectionCreateRequest.safeParse(request); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); + } + return baseApiService.post(`/api/v1/model-connections`, connectionRead, { + body: parsed.data, + }); + }; + + updateConnection = async (id: number, request: ConnectionUpdateRequest) => { + const parsed = connectionUpdateRequest.safeParse(request); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); + } + return baseApiService.put(`/api/v1/model-connections/${id}`, connectionRead, { + body: parsed.data, + }); + }; + + deleteConnection = async (id: number) => { + return baseApiService.delete(`/api/v1/model-connections/${id}`, undefined); + }; + + verifyConnection = async (id: number) => { + return baseApiService.post(`/api/v1/model-connections/${id}/verify`, verifyConnectionResponse); + }; + + discoverModels = async (id: number) => { + return baseApiService.post(`/api/v1/model-connections/${id}/discover`, modelListResponse); + }; + + updateModel = async (id: number, request: ModelUpdateRequest) => { + const parsed = modelUpdateRequest.safeParse(request); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); + } + return baseApiService.put(`/api/v1/models/${id}`, modelRead, { + body: parsed.data, + }); + }; + + testModel = async (id: number) => { + return baseApiService.post(`/api/v1/models/${id}/test`, verifyConnectionResponse); + }; + + getModelRoles = async (searchSpaceId: number) => { + return baseApiService.get(`/api/v1/search-spaces/${searchSpaceId}/model-roles`, modelRoles); + }; + + updateModelRoles = async (searchSpaceId: number, roles: ModelRoles) => { + return baseApiService.put(`/api/v1/search-spaces/${searchSpaceId}/model-roles`, modelRoles, { + body: roles, + }); + }; +} + +export const modelConnectionsApiService = new ModelConnectionsApiService(); diff --git a/surfsense_web/lib/query-client/cache-keys.ts b/surfsense_web/lib/query-client/cache-keys.ts index 6f8885d7e..558a73f95 100644 --- a/surfsense_web/lib/query-client/cache-keys.ts +++ b/surfsense_web/lib/query-client/cache-keys.ts @@ -44,6 +44,11 @@ export const cacheKeys = { global: () => ["new-llm-configs", "global"] as const, modelList: () => ["models", "catalogue"] as const, }, + modelConnections: { + all: (searchSpaceId: number) => ["model-connections", searchSpaceId] as const, + global: () => ["model-connections", "global"] as const, + roles: (searchSpaceId: number) => ["model-roles", searchSpaceId] as const, + }, imageGenConfigs: { all: (searchSpaceId: number) => ["image-gen-configs", searchSpaceId] as const, byId: (configId: number) => ["image-gen-configs", "detail", configId] as const, From 4bda0ffa9691c3069798d58a095dac58d9dd2042 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:49:33 +0530 Subject: [PATCH 008/212] feat(settings): add model connection management UI --- .../search-space-settings/layout-shell.tsx | 23 +- .../search-space-settings/models/page.tsx | 4 +- .../components/settings/llm-role-manager.tsx | 4 +- .../settings/model-connections-settings.tsx | 384 ++++++++++++++++++ 4 files changed, 389 insertions(+), 26 deletions(-) create mode 100644 surfsense_web/components/settings/model-connections-settings.tsx diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx index 22f68edab..9d9045004 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx @@ -5,9 +5,6 @@ import { Bot, CircleUser, Earth, - ImageIcon, - ListChecks, - ScanEye, UserKey, } from "lucide-react"; import Link from "next/link"; @@ -20,10 +17,7 @@ import { cn } from "@/lib/utils"; export type SearchSpaceSettingsTab = | "general" - | "roles" | "models" - | "image-models" - | "vision-models" | "team-roles" | "prompts" | "public-links"; @@ -57,26 +51,11 @@ export function SearchSpaceSettingsLayoutShell({ label: t("nav_general"), icon: , }, - { - value: "roles" as const, - label: t("nav_role_assignments"), - icon: , - }, { value: "models" as const, - label: t("nav_agent_models"), + label: t("nav_models"), icon: , }, - { - value: "image-models" as const, - label: t("nav_image_models"), - icon: , - }, - { - value: "vision-models" as const, - label: t("nav_vision_models"), - icon: , - }, { value: "team-roles" as const, label: t("nav_team_roles"), diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/models/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/models/page.tsx index d68194782..c97ef7630 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/models/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/models/page.tsx @@ -1,6 +1,6 @@ -import { AgentModelManager } from "@/components/settings/agent-model-manager"; +import { ModelConnectionsSettings } from "@/components/settings/model-connections-settings"; export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) { const { search_space_id } = await params; - return ; + return ; } diff --git a/surfsense_web/components/settings/llm-role-manager.tsx b/surfsense_web/components/settings/llm-role-manager.tsx index c32e79a8e..547675927 100644 --- a/surfsense_web/components/settings/llm-role-manager.tsx +++ b/surfsense_web/components/settings/llm-role-manager.tsx @@ -48,8 +48,8 @@ import { cn } from "@/lib/utils"; const ROLE_DESCRIPTIONS = { agent: { icon: Bot, - title: "Agent LLM", - description: "Primary LLM for chat interactions and agent operations", + title: "Chat model", + description: "Primary model for chat interactions and agent operations", color: "text-muted-foreground", bgColor: "bg-muted", prefKey: "agent_llm_id" as const, diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx new file mode 100644 index 000000000..5fa4cccf7 --- /dev/null +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -0,0 +1,384 @@ +"use client"; + +import { useAtom, useAtomValue } from "jotai"; +import { CheckCircle2, PlugZap, RefreshCcw, XCircle } from "lucide-react"; +import { useMemo, useState } from "react"; +import { + createModelConnectionMutationAtom, + discoverConnectionModelsMutationAtom, + testModelMutationAtom, + updateModelMutationAtom, + updateModelRolesMutationAtom, + verifyModelConnectionMutationAtom, +} from "@/atoms/model-connections/model-connections-mutation.atoms"; +import { + globalModelConnectionsAtom, + modelConnectionsAtom, + modelRolesAtom, +} from "@/atoms/model-connections/model-connections-query.atoms"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import type { + ConnectionProtocol, + ConnectionRead, + ModelRead, +} from "@/contracts/types/model-connections.types"; +import { isCloud } from "@/lib/env-config"; +import { getProviderIcon } from "@/lib/provider-icons"; + +type Preset = { + id: string; + label: string; + protocol: ConnectionProtocol; + nativeProvider?: string; + baseUrl?: string; + local?: boolean; +}; + +const PRESETS: Preset[] = [ + { id: "openai", label: "OpenAI", protocol: "NATIVE", nativeProvider: "OPENAI" }, + { id: "anthropic", label: "Anthropic", protocol: "NATIVE", nativeProvider: "ANTHROPIC" }, + { id: "openrouter", label: "OpenRouter", protocol: "NATIVE", nativeProvider: "OPENROUTER" }, + { + id: "ollama", + label: "Ollama", + protocol: "OLLAMA", + baseUrl: "http://host.docker.internal:11434", + local: true, + }, + { + id: "lmstudio", + label: "LM Studio", + protocol: "OPENAI_COMPATIBLE", + baseUrl: "http://host.docker.internal:1234/v1", + local: true, + }, + { + id: "llamacpp", + label: "llama.cpp", + protocol: "OPENAI_COMPATIBLE", + baseUrl: "http://host.docker.internal:8080/v1", + local: true, + }, + { + id: "localai", + label: "LocalAI", + protocol: "OPENAI_COMPATIBLE", + baseUrl: "http://host.docker.internal:8080/v1", + local: true, + }, + { + id: "vllm", + label: "vLLM", + protocol: "OPENAI_COMPATIBLE", + baseUrl: "http://host.docker.internal:8000/v1", + local: true, + }, +]; + +function modelLabel(model: ModelRead) { + return model.display_name || model.model_id; +} + +function capability(model: ModelRead, key: "chat" | "vision" | "image_gen") { + return Boolean(model.capabilities?.[key]); +} + +function StatusBadge({ connection }: { connection: ConnectionRead }) { + if (connection.last_status === "OK") { + return ( + + Healthy + + ); + } + if (connection.last_status) { + return ( + + {connection.last_status} + + ); + } + return Not tested; +} + +function flattenModels(connections: ConnectionRead[]) { + return connections.flatMap((connection) => + connection.models.map((model) => ({ + ...model, + connectionName: connection.native_provider || connection.protocol, + connectionId: connection.id, + provider: connection.native_provider || connection.protocol, + })) + ); +} + +export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: number }) { + const [{ data: globalConnections = [] }] = useAtom(globalModelConnectionsAtom); + const [{ data: connections = [] }] = useAtom(modelConnectionsAtom); + const [{ data: roles }] = useAtom(modelRolesAtom); + const createConnection = useAtomValue(createModelConnectionMutationAtom); + const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom); + const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); + const updateModel = useAtomValue(updateModelMutationAtom); + const testModel = useAtomValue(testModelMutationAtom); + const updateRoles = useAtomValue(updateModelRolesMutationAtom); + + const visiblePresets = useMemo( + () => PRESETS.filter((preset) => !(isCloud() && preset.local)), + [] + ); + const [presetId, setPresetId] = useState(visiblePresets[0]?.id ?? "openai"); + const preset = visiblePresets.find((item) => item.id === presetId) ?? visiblePresets[0]; + const [baseUrl, setBaseUrl] = useState(preset?.baseUrl ?? ""); + const [apiKey, setApiKey] = useState(""); + + const allConnections = [...globalConnections, ...connections]; + const enabledModels = flattenModels(allConnections).filter((model) => model.enabled); + const chatModels = enabledModels.filter((model) => capability(model, "chat")); + const visionModels = enabledModels.filter((model) => capability(model, "vision")); + const imageModels = enabledModels.filter((model) => capability(model, "image_gen")); + + function onPresetChange(value: string) { + setPresetId(value); + const next = visiblePresets.find((item) => item.id === value); + setBaseUrl(next?.baseUrl ?? ""); + } + + function handleCreate() { + if (!preset) return; + createConnection.mutate({ + protocol: preset.protocol, + native_provider: preset.nativeProvider, + base_url: baseUrl || null, + api_key: apiKey || null, + scope: "SEARCH_SPACE", + search_space_id: searchSpaceId, + extra: {}, + enabled: true, + }); + } + + function renderModelOption(model: ModelRead & { connectionName: string; provider: string }) { + return ( + + + {getProviderIcon(model.provider, { className: "size-4" })} + {modelLabel(model)} · {model.connectionName} + + + ); + } + + return ( +
+ + + Model Connections + + Add credentials or local endpoints once, then discover reusable models. + + + +
+
+ + +
+
+ + setBaseUrl(event.target.value)} + placeholder="https://api.example.com/v1" + /> +
+
+ + setApiKey(event.target.value)} + placeholder="Optional for local models" + type="password" + /> +
+
+ +
+
+ {preset?.local ? ( +

+ Local URLs are tested from the backend container. Use host.docker.internal instead of + localhost. +

+ ) : null} + +
+ {connections.map((connection) => ( +
+
+
+
+ {getProviderIcon(connection.native_provider || connection.protocol, { + className: "size-4", + })} + {connection.native_provider || connection.protocol} +
+
+ {connection.base_url || "Provider default endpoint"} +
+
+
+ + + +
+
+ {connection.last_error ? ( +

{connection.last_error}

+ ) : null} +
+ {connection.models.map((model) => ( +
+
+
+ {getProviderIcon(connection.native_provider || connection.protocol, { + className: "size-4", + })} + {modelLabel(model)} +
+
+ {["chat", "vision", "image_gen"] + .filter((key) => Boolean(model.capabilities?.[key])) + .join(", ") || "No verified capabilities"} +
+
+
+ + +
+
+ ))} +
+
+ ))} +
+
+
+ + + + Model Roles + + Pick which enabled model powers chat, vision, and image generation for this search + space. + + + +
+ + +
+
+ + +
+
+ + +
+
+
+
+ ); +} From 39cca36c31bd17483b365951b361c0ff4e70c4ec Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:49:55 +0530 Subject: [PATCH 009/212] feat(onboarding): include model connections in setup flow --- .../[search_space_id]/client-layout.tsx | 56 ++--- .../[search_space_id]/onboard/page.tsx | 216 +++++------------- surfsense_web/lib/posthog/events.ts | 12 +- 3 files changed, 96 insertions(+), 188 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx index 3a41b5998..c7e05fe99 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx @@ -4,15 +4,15 @@ import { useAtomValue, useSetAtom } from "jotai"; import { useParams, usePathname, useRouter } from "next/navigation"; import { useTranslations } from "next-intl"; import type React from "react"; -import { useCallback, useEffect, useRef, useState } from "react"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom"; import { myAccessAtom } from "@/atoms/members/members-query.atoms"; -import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { updateModelRolesMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { - globalNewLLMConfigsAtom, - llmPreferencesAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; + globalModelConnectionsAtom, + modelRolesAtom, +} from "@/atoms/model-connections/model-connections-query.atoms"; import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; import { DocumentUploadDialogProvider } from "@/components/assistant-ui/document-upload-popup"; import { LayoutDataProvider } from "@/components/layout"; @@ -21,7 +21,6 @@ import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/com import { useFolderSync } from "@/hooks/use-folder-sync"; import { useGlobalLoadingEffect } from "@/hooks/use-global-loading"; import { useElectronAPI } from "@/hooks/use-platform"; -import { isLlmOnboardingComplete } from "@/lib/onboarding"; export function DashboardClientLayout({ children, @@ -38,18 +37,27 @@ export function DashboardClientLayout({ const setPendingUserImageUrls = useSetAtom(pendingUserImageDataUrlsAtom); const { - data: preferences = {}, + data: modelRoles = {}, isFetching: loading, error, - refetch: refetchPreferences, - } = useAtomValue(llmPreferencesAtom); - const { data: globalConfigs = [], isFetching: globalConfigsLoading } = - useAtomValue(globalNewLLMConfigsAtom); - const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); + refetch: refetchModelRoles, + } = useAtomValue(modelRolesAtom); + const { data: globalConnections = [], isFetching: globalConfigsLoading } = useAtomValue( + globalModelConnectionsAtom + ); + const { mutateAsync: updateModelRoles } = useAtomValue(updateModelRolesMutationAtom); + + const firstGlobalChatModel = useMemo(() => { + for (const connection of globalConnections) { + const model = connection.models.find((item) => item.enabled && item.capabilities?.chat); + if (model) return model; + } + return null; + }, [globalConnections]); const isOnboardingComplete = useCallback(() => { - return isLlmOnboardingComplete(preferences.agent_llm_id, globalConfigs.length > 0); - }, [preferences.agent_llm_id, globalConfigs.length]); + return (modelRoles.chat_model_id ?? 0) !== 0 || Boolean(firstGlobalChatModel); + }, [modelRoles.chat_model_id, firstGlobalChatModel]); const { data: access = null, isLoading: accessLoading } = useAtomValue(myAccessAtom); const [hasCheckedOnboarding, setHasCheckedOnboarding] = useState(false); @@ -84,24 +92,18 @@ export function DashboardClientLayout({ return; } - if (globalConfigs.length > 0 && !hasAttemptedAutoConfig.current) { + if (firstGlobalChatModel && !hasAttemptedAutoConfig.current) { hasAttemptedAutoConfig.current = true; setIsAutoConfiguring(true); const autoConfigureWithGlobal = async () => { try { - const firstGlobalConfig = globalConfigs[0]; - await updatePreferences({ - search_space_id: Number(searchSpaceId), - data: { - agent_llm_id: firstGlobalConfig.id, - }, - }); + await updateModelRoles({ chat_model_id: firstGlobalChatModel.id }); - await refetchPreferences(); + await refetchModelRoles(); toast.success("AI configured automatically!", { - description: `Using ${firstGlobalConfig.name}. Customize in Settings.`, + description: `Using ${firstGlobalChatModel.display_name || firstGlobalChatModel.model_id}. Customize in Settings.`, }); setHasCheckedOnboarding(true); @@ -128,12 +130,12 @@ export function DashboardClientLayout({ isOnboardingPage, isOwner, isAutoConfiguring, - globalConfigs, + firstGlobalChatModel, router, searchSpaceId, hasCheckedOnboarding, - updatePreferences, - refetchPreferences, + updateModelRoles, + refetchModelRoles, ]); const electronAPI = useElectronAPI(); diff --git a/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx index de5c961e8..9cf429a3a 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx @@ -2,193 +2,99 @@ import { useAtomValue } from "jotai"; import { useParams, useRouter } from "next/navigation"; -import { useEffect, useRef, useState } from "react"; +import { useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; +import { updateModelRolesMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { - createNewLLMConfigMutationAtom, - updateLLMPreferencesMutationAtom, -} from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; -import { - globalNewLLMConfigsAtom, - llmPreferencesAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; + globalModelConnectionsAtom, + modelRolesAtom, +} from "@/atoms/model-connections/model-connections-query.atoms"; import { Logo } from "@/components/Logo"; -import { LLMConfigForm, type LLMConfigFormData } from "@/components/shared/llm-config-form"; import { Button } from "@/components/ui/button"; import { Spinner } from "@/components/ui/spinner"; import { useGlobalLoadingEffect } from "@/hooks/use-global-loading"; import { getBearerToken, redirectToLogin } from "@/lib/auth-utils"; -import { isLlmOnboardingComplete } from "@/lib/onboarding"; export default function OnboardPage() { const router = useRouter(); const params = useParams(); const searchSpaceId = Number(params.search_space_id); - // Queries - const { - data: globalConfigs = [], - isFetching: globalConfigsLoading, - isSuccess: globalConfigsLoaded, - } = useAtomValue(globalNewLLMConfigsAtom); - const { data: preferences = {}, isFetching: preferencesLoading } = - useAtomValue(llmPreferencesAtom); - - // Mutations - const { mutateAsync: createConfig, isPending: isCreating } = useAtomValue( - createNewLLMConfigMutationAtom + const { data: globalConnections = [], isFetching: globalLoading } = useAtomValue( + globalModelConnectionsAtom ); - const { mutateAsync: updatePreferences, isPending: isUpdatingPreferences } = useAtomValue( - updateLLMPreferencesMutationAtom - ); - - // State + const { data: roles = {}, isFetching: rolesLoading } = useAtomValue(modelRolesAtom); + const { mutateAsync: updateRoles, isPending } = useAtomValue(updateModelRolesMutationAtom); const [isAutoConfiguring, setIsAutoConfiguring] = useState(false); const hasAttemptedAutoConfig = useRef(false); - // Check authentication useEffect(() => { - const token = getBearerToken(); - if (!token) { - redirectToLogin(); - } + if (!getBearerToken()) redirectToLogin(); }, []); - const isOnboardingComplete = isLlmOnboardingComplete( - preferences.agent_llm_id, - globalConfigs.length > 0 - ); - - useEffect(() => { - if (!preferencesLoading && globalConfigsLoaded && isOnboardingComplete) { - router.push(`/dashboard/${searchSpaceId}/new-chat`); + const firstGlobalChatModel = useMemo(() => { + for (const connection of globalConnections) { + const model = connection.models.find((item) => item.enabled && item.capabilities?.chat); + if (model) return model; } - }, [preferencesLoading, globalConfigsLoaded, isOnboardingComplete, router, searchSpaceId]); + return null; + }, [globalConnections]); + + const isComplete = (roles.chat_model_id ?? 0) !== 0 || Boolean(firstGlobalChatModel); useEffect(() => { - const autoConfigureWithGlobal = async () => { - if (hasAttemptedAutoConfig.current) return; - if (globalConfigsLoading || preferencesLoading) return; - if (!globalConfigsLoaded) return; - if (isOnboardingComplete) return; + if (globalLoading || rolesLoading || hasAttemptedAutoConfig.current) return; + if ((roles.chat_model_id ?? 0) !== 0) { + router.push(`/dashboard/${searchSpaceId}/new-chat`); + return; + } + if (!firstGlobalChatModel) return; - if (globalConfigs.length > 0) { - hasAttemptedAutoConfig.current = true; - setIsAutoConfiguring(true); - - try { - const firstGlobalConfig = globalConfigs[0]; - - await updatePreferences({ - search_space_id: searchSpaceId, - data: { - agent_llm_id: firstGlobalConfig.id, - }, - }); - - toast.success("AI configured automatically!", { - description: `Using ${firstGlobalConfig.name}. You can customize this later in Settings.`, - }); - - router.push(`/dashboard/${searchSpaceId}/new-chat`); - } catch (error) { - console.error("Auto-configuration failed:", error); - toast.error("Auto-configuration failed. Please add a configuration manually."); - setIsAutoConfiguring(false); - } - } - }; - - autoConfigureWithGlobal(); + hasAttemptedAutoConfig.current = true; + setIsAutoConfiguring(true); + updateRoles({ chat_model_id: firstGlobalChatModel.id }) + .then(() => { + toast.success("AI configured automatically", { + description: `Using ${firstGlobalChatModel.display_name || firstGlobalChatModel.model_id}.`, + }); + router.push(`/dashboard/${searchSpaceId}/new-chat`); + }) + .catch((error) => { + console.error("Auto-configuration failed:", error); + toast.error("Auto-configuration failed. Add a connection manually."); + setIsAutoConfiguring(false); + }); }, [ - globalConfigs, - globalConfigsLoading, - globalConfigsLoaded, - preferencesLoading, - isOnboardingComplete, - updatePreferences, - searchSpaceId, + firstGlobalChatModel, + globalLoading, + roles.chat_model_id, + rolesLoading, router, + searchSpaceId, + updateRoles, ]); - const handleSubmit = async (formData: LLMConfigFormData) => { - try { - const newConfig = await createConfig(formData); - - await updatePreferences({ - search_space_id: searchSpaceId, - data: { - agent_llm_id: newConfig.id, - }, - }); - - toast.success("Configuration created!", { - description: "Redirecting to chat...", - }); - - router.push(`/dashboard/${searchSpaceId}/new-chat`); - } catch (error) { - console.error("Failed to create config:", error); - if (error instanceof Error) { - toast.error(error.message || "Failed to create configuration"); - } - } - }; - - const isSubmitting = isCreating || isUpdatingPreferences; - - const isLoading = globalConfigsLoading || preferencesLoading || isAutoConfiguring; + const isLoading = globalLoading || rolesLoading || isAutoConfiguring || isPending; useGlobalLoadingEffect(isLoading); - if (isLoading) { - return null; - } - - if (globalConfigs.length > 0 && !isAutoConfiguring) { - return null; - } + if (isLoading || isComplete) return null; return ( -
-
- {/* Header */} -
- -
-

Configure Your AI

-

- Add your LLM provider to get started with SurfSense -

-
-
- - {/* Form card */} -
- -
- - {/* Footer */} -
- -

You can add more configurations later

+
+
+ +
+

Connect a Model

+

+ Add one connection, discover its models, then choose a chat model for this search space. +

+ + {isPending ? : null}
); diff --git a/surfsense_web/lib/posthog/events.ts b/surfsense_web/lib/posthog/events.ts index 4dc644d5e..5aac8943d 100644 --- a/surfsense_web/lib/posthog/events.ts +++ b/surfsense_web/lib/posthog/events.ts @@ -609,9 +609,9 @@ interface AutomationCreatedProps { task_count?: number; trigger_type?: string; has_schedule?: boolean; - agent_llm_id?: number; - image_generation_config_id?: number; - vision_llm_config_id?: number; + chat_model_id?: number; + image_gen_model_id?: number; + vision_model_id?: number; tags_count?: number; } @@ -705,9 +705,9 @@ interface AutomationChatDecisionProps { edited?: boolean; task_count?: number; trigger_type?: string; - agent_llm_id?: number; - image_generation_config_id?: number; - vision_llm_config_id?: number; + chat_model_id?: number; + image_gen_model_id?: number; + vision_model_id?: number; } export function trackAutomationChatApproved(props: AutomationChatDecisionProps) { From e46283992907102421ef6d95fcdb92f9f3ca989e Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:50:10 +0530 Subject: [PATCH 010/212] refactor(chat): simplify model selector connection flow --- .../components/new-chat/model-selector.tsx | 1542 ++--------------- surfsense_web/content/docs/how-to/ollama.mdx | 2 +- 2 files changed, 173 insertions(+), 1371 deletions(-) diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 0a096f5f8..7c912afbb 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -1,53 +1,27 @@ "use client"; -import { useAtomValue } from "jotai"; +import { useAtom, useAtomValue } from "jotai"; +import { Bot, Check, ChevronDown, ImageOff, Search, Settings2, Zap } from "lucide-react"; +import { useMemo, useState } from "react"; +import { updateModelRolesMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { - Bot, - Check, - ChevronDown, - ChevronLeft, - ChevronRight, - ChevronUp, - ImageIcon, - Layers, - Pencil, - Plus, - ScanEye, - Search, - Zap, -} from "lucide-react"; -import type React from "react"; -import { Fragment, useCallback, useEffect, useMemo, useRef, useState } from "react"; -import { toast } from "sonner"; -import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom"; -import { - globalImageGenConfigsAtom, - imageGenConfigsAtom, -} from "@/atoms/image-gen-config/image-gen-config-query.atoms"; -import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; -import { - globalNewLLMConfigsAtom, - llmPreferencesAtom, - newLLMConfigsAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; -import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; -import { - globalVisionLLMConfigsAtom, - visionLLMConfigsAtom, -} from "@/atoms/vision-llm-config/vision-llm-config-query.atoms"; + globalModelConnectionsAtom, + modelConnectionsAtom, + modelRolesAtom, +} from "@/atoms/model-connections/model-connections-query.atoms"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Drawer, DrawerContent, - DrawerHandle, DrawerHeader, DrawerTitle, DrawerTrigger, } from "@/components/ui/drawer"; +import { Input } from "@/components/ui/input"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Spinner } from "@/components/ui/spinner"; -import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; import type { GlobalImageGenConfig, GlobalNewLLMConfig, @@ -60,272 +34,6 @@ import { useIsMobile } from "@/hooks/use-mobile"; import { getProviderIcon } from "@/lib/provider-icons"; import { cn } from "@/lib/utils"; -// ─── Helpers ──────────────────────────────────────────────────────── - -const PROVIDER_NAMES: Record = { - OPENAI: "OpenAI", - ANTHROPIC: "Anthropic", - GOOGLE: "Google", - AZURE: "Azure", - AZURE_OPENAI: "Azure OpenAI", - AWS_BEDROCK: "AWS Bedrock", - BEDROCK: "Bedrock", - DEEPSEEK: "DeepSeek", - MISTRAL: "Mistral", - COHERE: "Cohere", - GITHUB_MODELS: "GitHub Models", - GROQ: "Groq", - OLLAMA: "Ollama", - TOGETHER_AI: "Together AI", - FIREWORKS_AI: "Fireworks AI", - REPLICATE: "Replicate", - HUGGINGFACE: "HuggingFace", - PERPLEXITY: "Perplexity", - XAI: "xAI", - OPENROUTER: "OpenRouter", - CEREBRAS: "Cerebras", - SAMBANOVA: "SambaNova", - VERTEX_AI: "Vertex AI", - MINIMAX: "MiniMax", - MOONSHOT: "Moonshot", - ZHIPU: "Zhipu", - DEEPINFRA: "DeepInfra", - CLOUDFLARE: "Cloudflare", - DATABRICKS: "Databricks", - NSCALE: "NScale", - RECRAFT: "Recraft", - XINFERENCE: "XInference", - CUSTOM: "Custom", - AI21: "AI21", - ALIBABA_QWEN: "Qwen", - ANYSCALE: "Anyscale", - COMETAPI: "CometAPI", -}; - -// Provider keys valid per model type, matching backend enums -// (LiteLLMProvider, ImageGenProvider, VisionProvider in db.py) -const LLM_PROVIDER_KEYS: string[] = [ - "OPENAI", - "ANTHROPIC", - "GOOGLE", - "AZURE_OPENAI", - "BEDROCK", - "VERTEX_AI", - "GROQ", - "DEEPSEEK", - "XAI", - "MISTRAL", - "COHERE", - "OPENROUTER", - "TOGETHER_AI", - "FIREWORKS_AI", - "REPLICATE", - "PERPLEXITY", - "OLLAMA", - "CEREBRAS", - "SAMBANOVA", - "DEEPINFRA", - "AI21", - "ALIBABA_QWEN", - "MOONSHOT", - "ZHIPU", - "MINIMAX", - "HUGGINGFACE", - "CLOUDFLARE", - "DATABRICKS", - "ANYSCALE", - "COMETAPI", - "GITHUB_MODELS", - "CUSTOM", -]; - -const IMAGE_PROVIDER_KEYS: string[] = [ - "OPENAI", - "AZURE_OPENAI", - "GOOGLE", - "VERTEX_AI", - "BEDROCK", - "RECRAFT", - "OPENROUTER", - "XINFERENCE", - "NSCALE", -]; - -const VISION_PROVIDER_KEYS: string[] = [ - "OPENAI", - "ANTHROPIC", - "GOOGLE", - "AZURE_OPENAI", - "VERTEX_AI", - "BEDROCK", - "XAI", - "OPENROUTER", - "OLLAMA", - "GROQ", - "TOGETHER_AI", - "FIREWORKS_AI", - "DEEPSEEK", - "MISTRAL", - "CUSTOM", -]; - -const PROVIDER_KEYS_BY_TAB: Record = { - llm: LLM_PROVIDER_KEYS, - image: IMAGE_PROVIDER_KEYS, - vision: VISION_PROVIDER_KEYS, -}; - -function formatProviderName(provider: string): string { - const key = provider.toUpperCase(); - return ( - PROVIDER_NAMES[key] ?? - provider.charAt(0).toUpperCase() + provider.slice(1).toLowerCase().replace(/_/g, " ") - ); -} - -function normalizeText(input: string): string { - return input - .normalize("NFD") - .replace(/\p{Diacritic}/gu, "") - .toLowerCase() - .replace(/[^a-z0-9]+/g, " ") - .trim(); -} - -interface ConfigBase { - id: number; - name: string; - model_name: string; - provider: string; -} - -function filterAndScore( - configs: T[], - selectedProvider: string, - searchQuery: string -): T[] { - let result = configs; - - if (selectedProvider !== "all") { - result = result.filter((c) => c.provider.toUpperCase() === selectedProvider); - } - - if (!searchQuery.trim()) return result; - - const normalized = normalizeText(searchQuery); - const tokens = normalized.split(/\s+/).filter(Boolean); - - const scored = result.map((c) => { - const aggregate = normalizeText([c.name, c.model_name, c.provider].join(" ")); - let score = 0; - if (aggregate.includes(normalized)) score += 5; - for (const token of tokens) { - if (aggregate.includes(token)) score += 1; - } - return { config: c, score }; - }); - - return scored - .filter((s) => s.score > 0) - .sort((a, b) => b.score - a.score) - .map((s) => s.config); -} - -interface DisplayItem { - config: ConfigBase & Record; - isGlobal: boolean; - isAutoMode: boolean; -} - -const TruncatedNameWithTooltip: React.FC<{ - text: string; - className?: string; - enableTooltip: boolean; -}> = ({ text, className, enableTooltip }) => { - const textRef = useRef(null); - const openTimerRef = useRef(undefined); - const [isTruncated, setIsTruncated] = useState(false); - const [open, setOpen] = useState(false); - - const recalcTruncation = useCallback(() => { - const el = textRef.current; - if (!el) return; - setIsTruncated(el.scrollWidth > el.clientWidth + 1); - }, []); - - useEffect(() => { - if (!enableTooltip) return; - const el = textRef.current; - if (!el) return; - - const raf = requestAnimationFrame(recalcTruncation); - recalcTruncation(); - - const observer = new ResizeObserver(recalcTruncation); - observer.observe(el); - if (el.parentElement) observer.observe(el.parentElement); - window.addEventListener("resize", recalcTruncation); - - return () => { - cancelAnimationFrame(raf); - observer.disconnect(); - window.removeEventListener("resize", recalcTruncation); - }; - }, [enableTooltip, recalcTruncation]); - - useEffect(() => { - // Recompute when row text changes. - void text; - requestAnimationFrame(recalcTruncation); - }, [text, recalcTruncation]); - - useEffect( - () => () => { - if (openTimerRef.current) window.clearTimeout(openTimerRef.current); - }, - [] - ); - - if (!enableTooltip) { - return ( - - {text} - - ); - } - - const handleOpenChange = (nextOpen: boolean) => { - if (openTimerRef.current) { - window.clearTimeout(openTimerRef.current); - openTimerRef.current = undefined; - } - if (!nextOpen) { - setOpen(false); - return; - } - if (!isTruncated) return; - openTimerRef.current = window.setTimeout(() => { - setOpen(true); - openTimerRef.current = undefined; - }, 220); - }; - - return ( - - - - {text} - - - - {text} - - - ); -}; - -// ─── Component ────────────────────────────────────────────────────── - interface ModelSelectorProps { onEditLLM: (config: NewLLMConfigPublic | GlobalNewLLMConfig, isGlobal: boolean) => void; onAddNewLLM: (provider?: string) => void; @@ -336,1113 +44,207 @@ interface ModelSelectorProps { className?: string; } +type ChatModel = ModelRead & { + connectionId: number; + connectionLabel: string; + provider: string; +}; + +function modelName(model: ModelRead) { + return model.display_name || model.model_id; +} + +function connectionLabel(connection: ConnectionRead) { + if (connection.scope === "GLOBAL") return "Hosted"; + return connection.native_provider || connection.protocol; +} + +function flattenChatModels(connections: ConnectionRead[]) { + return connections.flatMap((connection) => + connection.models + .filter((model) => model.enabled && Boolean(model.capabilities?.chat)) + .map((model) => ({ + ...model, + connectionId: connection.id, + connectionLabel: connectionLabel(connection), + provider: connection.native_provider || connection.protocol, + })) + ); +} + +function groupedModels(models: ChatModel[]) { + return models.reduce>((groups, model) => { + const key = model.connectionLabel; + if (!groups[key]) groups[key] = []; + groups[key].push(model); + return groups; + }, {}); +} + export function ModelSelector({ - onEditLLM, onAddNewLLM, + onEditLLM, onEditImage, onAddNewImage, onEditVision, onAddNewVision, className, }: ModelSelectorProps) { - const [open, setOpen] = useState(false); - const [activeTab, setActiveTab] = useState<"llm" | "image" | "vision">("llm"); - const [searchQuery, setSearchQuery] = useState(""); - const [selectedProvider, setSelectedProvider] = useState("all"); - const [focusedIndex, setFocusedIndex] = useState(-1); - const [modelScrollPos, setModelScrollPos] = useState<"top" | "middle" | "bottom">("top"); - const [sidebarScrollPos, setSidebarScrollPos] = useState<"top" | "middle" | "bottom">("top"); - const providerSidebarRef = useRef(null); - const modelListRef = useRef(null); - const searchInputRef = useRef(null); + void onEditLLM; + void onEditImage; + void onAddNewImage; + void onEditVision; + void onAddNewVision; + const isMobile = useIsMobile(); - - const handleOpenChange = useCallback( - (next: boolean) => { - if (next) { - setSearchQuery(""); - setSelectedProvider("all"); - if (!isMobile) { - requestAnimationFrame(() => searchInputRef.current?.focus()); - } - } - setOpen(next); - }, - [isMobile] + const [open, setOpen] = useState(false); + const [search, setSearch] = useState(""); + const [{ data: globalConnections = [], isLoading: globalLoading }] = useAtom( + globalModelConnectionsAtom ); + const [{ data: connections = [], isLoading: connectionsLoading }] = useAtom(modelConnectionsAtom); + const [{ data: roles }] = useAtom(modelRolesAtom); + const updateRoles = useAtomValue(updateModelRolesMutationAtom); - const handleTabChange = useCallback( - (next: "llm" | "image" | "vision") => { - setActiveTab(next); - setSelectedProvider("all"); - setSearchQuery(""); - setFocusedIndex(-1); - setModelScrollPos("top"); - if (open && !isMobile) { - requestAnimationFrame(() => searchInputRef.current?.focus()); - } - }, - [open, isMobile] - ); - - const handleModelListScroll = useCallback((e: React.UIEvent) => { - const el = e.currentTarget; - const atTop = el.scrollTop <= 2; - const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; - setModelScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); - }, []); - - const handleSidebarScroll = useCallback( - (e: React.UIEvent) => { - const el = e.currentTarget; - if (isMobile) { - const atStart = el.scrollLeft <= 2; - const atEnd = el.scrollWidth - el.scrollLeft - el.clientWidth <= 2; - setSidebarScrollPos(atStart ? "top" : atEnd ? "bottom" : "middle"); - } else { - const atTop = el.scrollTop <= 2; - const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; - setSidebarScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); - } - }, - [isMobile] - ); - - const scrollProviderSidebar = useCallback( - (direction: "backward" | "forward") => { - const el = providerSidebarRef.current; - if (!el) return; - const delta = isMobile - ? Math.max(56, Math.floor(el.clientWidth * 0.5)) - : Math.max(44, Math.floor(el.clientHeight * 0.4)); - - if (isMobile) { - el.scrollBy({ - left: direction === "backward" ? -delta : delta, - behavior: "smooth", - }); - return; - } - - el.scrollBy({ - top: direction === "backward" ? -delta : delta, - behavior: "smooth", - }); - }, - [isMobile] - ); - - // Cmd/Ctrl+M shortcut (desktop only) - useEffect(() => { - if (isMobile) return; - const handler = (e: KeyboardEvent) => { - if ((e.metaKey || e.ctrlKey) && e.key === "m") { - e.preventDefault(); - // setOpen((prev) => !prev); - handleOpenChange(!open); - } - }; - document.addEventListener("keydown", handler); - return () => document.removeEventListener("keydown", handler); - }, [isMobile, open, handleOpenChange]); - - // ─── Data ─── - const { data: llmUserConfigs, isLoading: llmUserLoading } = useAtomValue(newLLMConfigsAtom); - const { data: llmGlobalConfigs, isLoading: llmGlobalLoading } = - useAtomValue(globalNewLLMConfigsAtom); - const { data: preferences, isLoading: prefsLoading } = useAtomValue(llmPreferencesAtom); - const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom); - const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); - const { data: imageGlobalConfigs, isLoading: imageGlobalLoading } = - useAtomValue(globalImageGenConfigsAtom); - const { data: imageUserConfigs, isLoading: imageUserLoading } = useAtomValue(imageGenConfigsAtom); - const { data: visionGlobalConfigs, isLoading: visionGlobalLoading } = useAtomValue( - globalVisionLLMConfigsAtom - ); - const { data: visionUserConfigs, isLoading: visionUserLoading } = - useAtomValue(visionLLMConfigsAtom); - - // Pending image attachments on the composer. Used to surface an - // amber "No image" hint on chat models the catalog reports as - // non-vision (`supports_image_input=false`) when the next message - // will carry an image. The hint is purely advisory: selection, - // focus, and click handling are unaffected. The backend's safety - // net (`is_known_text_only_chat_model`) is the actual block, and - // it only fires when LiteLLM *explicitly* marks a model as - // text-only — so a model that's secretly capable but hasn't been - // annotated will still flow through to the provider. - const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom); - const hasPendingImages = pendingUserImageUrls.length > 0; - - const isLoading = - llmUserLoading || - llmGlobalLoading || - prefsLoading || - imageGlobalLoading || - imageUserLoading || - visionGlobalLoading || - visionUserLoading; - - // ─── Current selected configs ─── - const currentLLMConfig = useMemo(() => { - if (!preferences) return null; - const id = preferences.agent_llm_id; - if (id === null || id === undefined) return null; - if (id <= 0) return llmGlobalConfigs?.find((c) => c.id === id) ?? null; - return llmUserConfigs?.find((c) => c.id === id) ?? null; - }, [preferences, llmGlobalConfigs, llmUserConfigs]); - - const isLLMAutoMode = - currentLLMConfig && "is_auto_mode" in currentLLMConfig && currentLLMConfig.is_auto_mode; - - const currentImageConfig = useMemo(() => { - if (!preferences) return null; - const id = preferences.image_generation_config_id; - if (id === null || id === undefined) return null; - return ( - imageGlobalConfigs?.find((c) => c.id === id) ?? - imageUserConfigs?.find((c) => c.id === id) ?? - null + const chatModels = useMemo(() => { + const normalized = search.trim().toLowerCase(); + const models = flattenChatModels([...globalConnections, ...connections]); + if (!normalized) return models; + return models.filter((model) => + [modelName(model), model.model_id, model.connectionLabel] + .join(" ") + .toLowerCase() + .includes(normalized) ); - }, [preferences, imageGlobalConfigs, imageUserConfigs]); + }, [globalConnections, connections, search]); - const isImageAutoMode = - currentImageConfig && "is_auto_mode" in currentImageConfig && currentImageConfig.is_auto_mode; + const selected = chatModels.find((model) => model.id === roles?.chat_model_id); + const groups = groupedModels(chatModels); + const loading = globalLoading || connectionsLoading; - const currentVisionConfig = useMemo(() => { - if (!preferences) return null; - const id = preferences.vision_llm_config_id; - if (id === null || id === undefined) return null; - return ( - visionGlobalConfigs?.find((c) => c.id === id) ?? - visionUserConfigs?.find((c) => c.id === id) ?? - null - ); - }, [preferences, visionGlobalConfigs, visionUserConfigs]); + function selectModel(modelId: number) { + updateRoles.mutate({ chat_model_id: modelId }); + setOpen(false); + } - const isVisionAutoMode = - currentVisionConfig && - "is_auto_mode" in currentVisionConfig && - currentVisionConfig.is_auto_mode; - - // ─── Filtered configs (separate global / user for section headers) ─── - const filteredLLMGlobal = useMemo( - () => filterAndScore(llmGlobalConfigs ?? [], selectedProvider, searchQuery), - [llmGlobalConfigs, selectedProvider, searchQuery] - ); - const filteredLLMUser = useMemo( - () => filterAndScore(llmUserConfigs ?? [], selectedProvider, searchQuery), - [llmUserConfigs, selectedProvider, searchQuery] - ); - const filteredImageGlobal = useMemo( - () => filterAndScore(imageGlobalConfigs ?? [], selectedProvider, searchQuery), - [imageGlobalConfigs, selectedProvider, searchQuery] - ); - const filteredImageUser = useMemo( - () => filterAndScore(imageUserConfigs ?? [], selectedProvider, searchQuery), - [imageUserConfigs, selectedProvider, searchQuery] - ); - const filteredVisionGlobal = useMemo( - () => filterAndScore(visionGlobalConfigs ?? [], selectedProvider, searchQuery), - [visionGlobalConfigs, selectedProvider, searchQuery] - ); - const filteredVisionUser = useMemo( - () => filterAndScore(visionUserConfigs ?? [], selectedProvider, searchQuery), - [visionUserConfigs, selectedProvider, searchQuery] - ); - - // Combined display list for keyboard navigation - const currentDisplayItems: DisplayItem[] = useMemo(() => { - const toItems = (configs: ConfigBase[], isGlobal: boolean): DisplayItem[] => - configs.map((c) => ({ - config: c as ConfigBase & Record, - isGlobal, - isAutoMode: - isGlobal && "is_auto_mode" in c && !!(c as Record).is_auto_mode, - })); - - const sortGlobalItems = (items: DisplayItem[]): DisplayItem[] => - [...items].sort((a, b) => { - if (a.isAutoMode !== b.isAutoMode) return a.isAutoMode ? -1 : 1; - const aPremium = !!(a.config as Record).is_premium; - const bPremium = !!(b.config as Record).is_premium; - if (aPremium !== bPremium) return aPremium ? 1 : -1; - return 0; - }); - - switch (activeTab) { - case "llm": - return [ - ...sortGlobalItems(toItems(filteredLLMGlobal, true)), - ...toItems(filteredLLMUser, false), - ]; - case "image": - return [ - ...sortGlobalItems(toItems(filteredImageGlobal, true)), - ...toItems(filteredImageUser, false), - ]; - case "vision": - return [ - ...sortGlobalItems(toItems(filteredVisionGlobal, true)), - ...toItems(filteredVisionUser, false), - ]; - } - }, [ - activeTab, - filteredLLMGlobal, - filteredLLMUser, - filteredImageGlobal, - filteredImageUser, - filteredVisionGlobal, - filteredVisionUser, - ]); - - // ─── Provider sidebar data ─── - // Collect which providers actually have configured models for the active tab - const configuredProviderSet = useMemo(() => { - const configs = - activeTab === "llm" - ? [...(llmGlobalConfigs ?? []), ...(llmUserConfigs ?? [])] - : activeTab === "image" - ? [...(imageGlobalConfigs ?? []), ...(imageUserConfigs ?? [])] - : [...(visionGlobalConfigs ?? []), ...(visionUserConfigs ?? [])]; - const set = new Set(); - for (const c of configs) { - if (c.provider) set.add(c.provider.toUpperCase()); - } - return set; - }, [ - activeTab, - llmGlobalConfigs, - llmUserConfigs, - imageGlobalConfigs, - imageUserConfigs, - visionGlobalConfigs, - visionUserConfigs, - ]); - - // Show only providers valid for the active tab; configured ones first - const activeProviders = useMemo(() => { - const tabKeys = PROVIDER_KEYS_BY_TAB[activeTab] ?? LLM_PROVIDER_KEYS; - const configured = tabKeys.filter((p) => configuredProviderSet.has(p)); - const unconfigured = tabKeys.filter((p) => !configuredProviderSet.has(p)); - return ["all", ...configured, ...unconfigured]; - }, [activeTab, configuredProviderSet]); - - const providerModelCounts = useMemo(() => { - const allConfigs = - activeTab === "llm" - ? [...(llmGlobalConfigs ?? []), ...(llmUserConfigs ?? [])] - : activeTab === "image" - ? [...(imageGlobalConfigs ?? []), ...(imageUserConfigs ?? [])] - : [...(visionGlobalConfigs ?? []), ...(visionUserConfigs ?? [])]; - const counts: Record = { all: allConfigs.length }; - for (const c of allConfigs) { - const p = c.provider.toUpperCase(); - counts[p] = (counts[p] || 0) + 1; - } - return counts; - }, [ - activeTab, - llmGlobalConfigs, - llmUserConfigs, - imageGlobalConfigs, - imageUserConfigs, - visionGlobalConfigs, - visionUserConfigs, - ]); - - // ─── Selection handlers ─── - const handleSelectLLM = useCallback( - async (config: NewLLMConfigPublic | GlobalNewLLMConfig) => { - if (currentLLMConfig?.id === config.id) { - setOpen(false); - return; - } - if (!searchSpaceId) { - toast.error("No search space selected"); - return; - } - try { - await updatePreferences({ - search_space_id: Number(searchSpaceId), - data: { agent_llm_id: config.id }, - }); - toast.success(`Switched to ${config.name}`); - setOpen(false); - } catch { - toast.error("Failed to switch model"); - } - }, - [currentLLMConfig, searchSpaceId, updatePreferences] - ); - - const handleSelectImage = useCallback( - async (configId: number) => { - if (currentImageConfig?.id === configId) { - setOpen(false); - return; - } - if (!searchSpaceId) { - toast.error("No search space selected"); - return; - } - try { - await updatePreferences({ - search_space_id: Number(searchSpaceId), - data: { image_generation_config_id: configId }, - }); - toast.success("Image model updated"); - setOpen(false); - } catch { - toast.error("Failed to switch image model"); - } - }, - [currentImageConfig, searchSpaceId, updatePreferences] - ); - - const handleSelectVision = useCallback( - async (configId: number) => { - if (currentVisionConfig?.id === configId) { - setOpen(false); - return; - } - if (!searchSpaceId) { - toast.error("No search space selected"); - return; - } - try { - await updatePreferences({ - search_space_id: Number(searchSpaceId), - data: { vision_llm_config_id: configId }, - }); - toast.success("Vision model updated"); - setOpen(false); - } catch { - toast.error("Failed to switch vision model"); - } - }, - [currentVisionConfig, searchSpaceId, updatePreferences] - ); - - const handleSelectItem = useCallback( - (item: DisplayItem) => { - switch (activeTab) { - case "llm": - handleSelectLLM(item.config as NewLLMConfigPublic | GlobalNewLLMConfig); - break; - case "image": - handleSelectImage(item.config.id); - break; - case "vision": - handleSelectVision(item.config.id); - break; - } - }, - [activeTab, handleSelectLLM, handleSelectImage, handleSelectVision] - ); - - const handleEditItem = useCallback( - (e: React.MouseEvent, item: DisplayItem) => { - e.stopPropagation(); - setOpen(false); - switch (activeTab) { - case "llm": - onEditLLM(item.config as NewLLMConfigPublic | GlobalNewLLMConfig, item.isGlobal); - break; - case "image": - onEditImage?.(item.config as ImageGenerationConfig | GlobalImageGenConfig, item.isGlobal); - break; - case "vision": - onEditVision?.(item.config as VisionLLMConfig | GlobalVisionLLMConfig, item.isGlobal); - break; - } - }, - [activeTab, onEditLLM, onEditImage, onEditVision] - ); - - // ─── Keyboard navigation ─── - // biome-ignore lint/correctness/useExhaustiveDependencies: searchQuery and selectedProvider are intentional triggers to reset focus - useEffect(() => { - setFocusedIndex(-1); - }, [searchQuery, selectedProvider]); - - useEffect(() => { - if (focusedIndex < 0 || !modelListRef.current) return; - const items = modelListRef.current.querySelectorAll("[data-model-index]"); - items[focusedIndex]?.scrollIntoView({ - block: "nearest", - behavior: "smooth", - }); - }, [focusedIndex]); - - const handleKeyDown = useCallback( - (e: React.KeyboardEvent) => { - const count = currentDisplayItems.length; - - // Arrow Left/Right cycle provider filters - if (e.key === "ArrowLeft" || e.key === "ArrowRight") { - e.preventDefault(); - const providers = activeProviders; - const idx = providers.indexOf(selectedProvider); - let next: number; - if (e.key === "ArrowLeft") { - next = idx > 0 ? idx - 1 : providers.length - 1; - } else { - next = idx < providers.length - 1 ? idx + 1 : 0; - } - setSelectedProvider(providers[next]); - if (providerSidebarRef.current) { - const buttons = providerSidebarRef.current.querySelectorAll("button"); - buttons[next]?.scrollIntoView({ - block: "nearest", - inline: "nearest", - behavior: "smooth", - }); - } - return; - } - - if (count === 0) return; - - switch (e.key) { - case "ArrowDown": - e.preventDefault(); - setFocusedIndex((prev) => (prev < count - 1 ? prev + 1 : 0)); - break; - case "ArrowUp": - e.preventDefault(); - setFocusedIndex((prev) => (prev > 0 ? prev - 1 : count - 1)); - break; - case "Enter": - e.preventDefault(); - if (focusedIndex >= 0 && focusedIndex < count) { - handleSelectItem(currentDisplayItems[focusedIndex]); - } - break; - case "Home": - e.preventDefault(); - setFocusedIndex(0); - break; - case "End": - e.preventDefault(); - setFocusedIndex(count - 1); - break; - } - }, - [currentDisplayItems, focusedIndex, activeProviders, selectedProvider, handleSelectItem] - ); - - // ─── Render: Provider sidebar ─── - const renderProviderSidebar = () => { - const configuredCount = configuredProviderSet.size; - - return ( -
- {!isMobile && ( -
- -
- )} - {isMobile && ( -
- -
- )} -
+
+
+ + setSearch(event.target.value)} + placeholder="Search chat models..." + className="pl-9" + /> +
+
+
+ - - - {isAll ? "All Models" : formatProviderName(provider)} - {isConfigured ? ` (${count})` : " (not configured)"} - - - - ); - })} -
- {!isMobile && ( -
- -
- )} - {isMobile && ( -
- -
- )} -
- ); - }; - - // ─── Render: Model card ─── - const getSelectedId = () => { - switch (activeTab) { - case "llm": - return currentLLMConfig?.id; - case "image": - return currentImageConfig?.id; - case "vision": - return currentVisionConfig?.id; - } - }; - - const renderModelCard = (item: DisplayItem, index: number) => { - const { config, isAutoMode } = item; - const isSelected = getSelectedId() === config.id; - const isFocused = focusedIndex === index; - const hasCitations = "citations_enabled" in config && !!config.citations_enabled; - const hasPremiumStatus = "is_premium" in config && !isAutoMode; - const isPremium = hasPremiumStatus && !!(config as Record).is_premium; - // Chat-tab only: surface an amber "No image" hint when the - // composer carries images and the catalog reports the model as - // non-vision. This is purely advisory — selection is *not* - // blocked. The backend's narrow safety net - // (`is_known_text_only_chat_model`) is the source of truth for - // rejecting image turns, and it only fires when LiteLLM - // explicitly marks the model as text-only. A model surfaced as - // `supports_image_input=false` here may still be capable in - // practice (unknown / unmapped LiteLLM entry), so we let the - // user pick it and the provider response decide. - const isImageIncompatibleChatModel = - activeTab === "llm" && - hasPendingImages && - "supports_image_input" in config && - (config as Record).supports_image_input === false; - - return ( -
handleSelectItem(item)} - onKeyDown={ - isMobile - ? undefined - : (e) => { - if (e.key === "Enter" || e.key === " ") { - e.preventDefault(); - handleSelectItem(item); - } - } - } - onMouseEnter={() => setFocusedIndex(index)} - className={cn( - "group flex items-center gap-2.5 px-3 py-2 rounded-xl", - "transition-colors duration-150 mx-2 cursor-pointer", - "hover:bg-accent hover:text-accent-foreground", - isFocused && "bg-accent text-accent-foreground", - isSelected && "bg-accent text-accent-foreground" - )} - > - {/* Provider icon */} -
- {getProviderIcon(config.provider as string, { - isAutoMode, - className: "size-5", - })} -
- - {/* Model info */} -
-
- - {isAutoMode && ( - - Recommended - - )} - {isImageIncompatibleChatModel && ( - - No image - - )} -
- {isAutoMode ? ( -
- Auto Mode +
+
+
- ) : ( - (hasPremiumStatus || hasCitations) && ( -
- {hasPremiumStatus && ( - - {isPremium ? "Premium" : "Free"} - - )} - {hasCitations && ( - - Citations - - )} +
+
Auto
+
Use the hosted/global router
+
+
+ {(roles?.chat_model_id ?? 0) === 0 ? : null} + + {loading ? ( +
+ +
+ ) : Object.keys(groups).length === 0 ? ( +
+ No enabled chat models. Add or enable models in Settings. +
+ ) : ( + Object.entries(groups).map(([connection, models]) => ( +
+
+ {connection}
- ) - )} -
- - {/* Actions */} -
- {!isAutoMode && ( - - )} - {isSelected && ( -
- -
- )} -
-
- ); - }; - - // ─── Render: Full content ─── - const renderContent = () => { - const globalItems = currentDisplayItems.filter((i) => i.isGlobal); - const userItems = currentDisplayItems.filter((i) => !i.isGlobal); - const globalStartIdx = 0; - const userStartIdx = globalItems.length; - - const addHandler = - activeTab === "llm" ? onAddNewLLM : activeTab === "image" ? onAddNewImage : onAddNewVision; - const addLabel = - activeTab === "llm" - ? "Add Model" - : activeTab === "image" - ? "Add Image Model" - : "Add Vision Model"; - - return ( -
- {/* Tab header */} -
-
- {( - [ - { - value: "llm" as const, - icon: Zap, - label: "LLM", - }, - { - value: "image" as const, - icon: ImageIcon, - label: "Image", - }, - { - value: "vision" as const, - icon: ScanEye, - label: "Vision", - }, - ] as const - ).map(({ value, icon: Icon, label }) => ( - - ))} -
-
- - {/* Two-pane layout */} -
- {/* Provider sidebar */} - {renderProviderSidebar()} - - {/* Main content */} -
- {/* Search */} -
- - setSearchQuery(e.target.value)} - onKeyDown={isMobile ? undefined : handleKeyDown} - role="combobox" - aria-expanded={true} - aria-controls="model-selector-list" - className={cn( - "w-full pl-8 pr-3 py-2.5 text-sm bg-transparent", - "focus:outline-none", - "placeholder:text-muted-foreground" - )} - /> -
- - {/* Provider header when filtered */} - {selectedProvider !== "all" && ( -
- {getProviderIcon(selectedProvider, { - className: "size-4", - })} - {formatProviderName(selectedProvider)} - - {configuredProviderSet.has(selectedProvider) - ? `${providerModelCounts[selectedProvider] || 0} models` - : "Not configured"} - -
- )} - - {/* Model list */} -
- {currentDisplayItems.length === 0 ? ( -
- {selectedProvider !== "all" && !configuredProviderSet.has(selectedProvider) ? ( - <> -
- {getProviderIcon(selectedProvider, { - className: "size-10", - })} -
-

- No {formatProviderName(selectedProvider)} models configured -

-

- Add a model with this provider to get started -

- {addHandler && ( - - )} - - ) : searchQuery ? ( - <> - -

No models found

-

- Try a different search term -

- - ) : ( - <> -

- No models configured -

-

- Configure models in your search space settings -

- - )} -
- ) : ( - <> - {globalItems.length > 0 && ( - <> -
- Global Models -
- {globalItems.map((item, i) => renderModelCard(item, globalStartIdx + i))} - - )} - {globalItems.length > 0 && userItems.length > 0 && ( -
- )} - {userItems.length > 0 && ( - <> -
- Your Configurations -
- {userItems.map((item, i) => renderModelCard(item, userStartIdx + i))} - - )} - - )} -
- - {/* Add model button */} - {addHandler && ( -
- -
- )} -
-
+
+
+ {getProviderIcon(model.provider, { className: "size-4 shrink-0" })} + {modelName(model)} +
+
{model.model_id}
+
+
+ {!model.capabilities?.vision ? ( + + No image + + ) : null} + {roles?.chat_model_id === model.id ? : null} +
+ + ))} +
+ )) + )}
- ); - }; +
+ +
+
+ ); - // ─── Trigger button ─── - const triggerButton = ( + const trigger = ( ); - // ─── Shell: Drawer on mobile, Popover on desktop ─── if (isMobile) { return ( - - {triggerButton} - - - - Select Model + + {trigger} + + + Select Chat Model -
{renderContent()}
+ {content}
); } return ( - - {triggerButton} - e.preventDefault()} - > - {renderContent()} + + {trigger} + + {content} ); diff --git a/surfsense_web/content/docs/how-to/ollama.mdx b/surfsense_web/content/docs/how-to/ollama.mdx index 48b231705..f5ec09e1b 100644 --- a/surfsense_web/content/docs/how-to/ollama.mdx +++ b/surfsense_web/content/docs/how-to/ollama.mdx @@ -22,7 +22,7 @@ If SurfSense runs in Docker, do not use `localhost` unless Ollama is in the same ## 2) Add Ollama in SurfSense -Go to **Search Space Settings -> Agent Models -> Add Model** and set: +Go to **Search Space Settings -> Models -> Add Model** and set: - Provider: `OLLAMA` - Model name: your model tag, for example `llama3.2` or `qwen3:8b` From 6c352021a093bd6b3b69307b093a934c83d9255b Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:50:22 +0530 Subject: [PATCH 011/212] chore(i18n): add model connection copy --- surfsense_web/messages/en.json | 3 ++- surfsense_web/messages/es.json | 2 +- surfsense_web/messages/hi.json | 2 +- surfsense_web/messages/pt.json | 2 +- surfsense_web/messages/zh.json | 2 +- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/surfsense_web/messages/en.json b/surfsense_web/messages/en.json index a13942e64..80badd84f 100644 --- a/surfsense_web/messages/en.json +++ b/surfsense_web/messages/en.json @@ -743,7 +743,8 @@ "back_to_app": "Back to app", "nav_general": "General", "nav_general_desc": "Name, description & basic info", - "nav_agent_models": "Agent Models", + "nav_models": "Models", + "nav_agent_models": "Chat Models", "nav_agent_models_desc": "Models with prompts & citations", "nav_role_assignments": "Role Assignments", "nav_role_assignments_desc": "Assign configs to agent roles", diff --git a/surfsense_web/messages/es.json b/surfsense_web/messages/es.json index 33ae79c52..d6225f532 100644 --- a/surfsense_web/messages/es.json +++ b/surfsense_web/messages/es.json @@ -743,7 +743,7 @@ "back_to_app": "Volver a la app", "nav_general": "General", "nav_general_desc": "Nombre, descripción e información básica", - "nav_agent_models": "Modelos de agente", + "nav_agent_models": "Modelos de chat", "nav_agent_models_desc": "Modelos LLM con prompts y citas", "nav_role_assignments": "Asignaciones de roles", "nav_role_assignments_desc": "Asignar configuraciones a roles de agente", diff --git a/surfsense_web/messages/hi.json b/surfsense_web/messages/hi.json index 7a26d0c1d..3cb3ad41a 100644 --- a/surfsense_web/messages/hi.json +++ b/surfsense_web/messages/hi.json @@ -743,7 +743,7 @@ "back_to_app": "ऐप पर वापस जाएं", "nav_general": "सामान्य", "nav_general_desc": "नाम, विवरण और बुनियादी जानकारी", - "nav_agent_models": "एजेंट मॉडल", + "nav_agent_models": "चैट मॉडल", "nav_agent_models_desc": "प्रॉम्प्ट और उद्धरण के साथ LLM मॉडल", "nav_role_assignments": "भूमिका असाइनमेंट", "nav_role_assignments_desc": "एजेंट भूमिकाओं को कॉन्फ़िगरेशन असाइन करें", diff --git a/surfsense_web/messages/pt.json b/surfsense_web/messages/pt.json index 61c22e086..96bfc096d 100644 --- a/surfsense_web/messages/pt.json +++ b/surfsense_web/messages/pt.json @@ -743,7 +743,7 @@ "back_to_app": "Voltar ao app", "nav_general": "Geral", "nav_general_desc": "Nome, descrição e informações básicas", - "nav_agent_models": "Modelos do agente", + "nav_agent_models": "Modelos de chat", "nav_agent_models_desc": "Modelos LLM com prompts e citações", "nav_role_assignments": "Atribuições de funções", "nav_role_assignments_desc": "Atribuir configurações a funções do agente", diff --git a/surfsense_web/messages/zh.json b/surfsense_web/messages/zh.json index 7d0419cbd..03b217f12 100644 --- a/surfsense_web/messages/zh.json +++ b/surfsense_web/messages/zh.json @@ -727,7 +727,7 @@ "back_to_app": "返回应用", "nav_general": "常规", "nav_general_desc": "名称、描述和基本信息", - "nav_agent_models": "代理模型", + "nav_agent_models": "聊天模型", "nav_agent_models_desc": "LLM 模型配置提示词和引用", "nav_role_assignments": "角色分配", "nav_role_assignments_desc": "为代理角色分配配置", From 85114d2a0e0ed04c67aee49fba12d341cad02322 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:50:42 +0530 Subject: [PATCH 012/212] refactor(chat): rename image generation config parameters for clarity --- .../multi_agent_chat/main_agent/runtime/agent_cache.py | 4 ++-- .../chat/multi_agent_chat/main_agent/runtime/factory.py | 8 ++++---- surfsense_backend/app/agents/chat/runtime/llm_config.py | 4 ++-- surfsense_backend/app/agents/podcaster/nodes.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache.py index 6ac22e575..2d3599de0 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache.py @@ -57,7 +57,7 @@ async def build_agent_with_cache( mcp_tools_by_agent: dict[str, list[BaseTool]], disabled_tools: list[str] | None, config_id: str | None, - image_generation_config_id_override: int | None = None, + image_gen_model_id_override: int | None = None, ) -> Any: """Compile the multi-agent graph, serving from cache when key components are stable.""" @@ -121,7 +121,7 @@ async def build_agent_with_cache( # Bound into the generate_image subagent tool at construction time, so it # must key the compiled-agent cache to avoid leaking one automation's # image model into another with the same config_id/search_space. - image_generation_config_id_override, + image_gen_model_id_override, ) return await get_cache().get_or_build(cache_key, builder=_build) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py index adb1bc1ed..10a734192 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py @@ -72,11 +72,11 @@ async def create_multi_agent_chat_deep_agent( mentioned_document_ids: list[int] | None = None, anon_session_id: str | None = None, filesystem_selection: FilesystemSelection | None = None, - image_generation_config_id: int | None = None, + image_gen_model_id: int | None = None, ): """Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled. - ``image_generation_config_id`` overrides the search space's image model for + ``image_gen_model_id`` overrides the search space's image model for this invocation (used by automations to run on their captured model). When ``None``, the ``generate_image`` tool resolves the live search-space pref. """ @@ -147,7 +147,7 @@ async def create_multi_agent_chat_deep_agent( "llm": llm, # Per-invocation image model override (automations run on their captured # model). Reaches the generate_image subagent tool via subagent_dependencies. - "image_generation_config_id_override": image_generation_config_id, + "image_gen_model_id_override": image_gen_model_id, } _t0 = time.perf_counter() @@ -303,7 +303,7 @@ async def create_multi_agent_chat_deep_agent( mcp_tools_by_agent=mcp_tools_by_agent, disabled_tools=disabled_tools, config_id=config_id, - image_generation_config_id_override=image_generation_config_id, + image_gen_model_id_override=image_gen_model_id, ) _perf_log.info( "[create_agent] Middleware stack + graph compiled in %.3fs", diff --git a/surfsense_backend/app/agents/chat/runtime/llm_config.py b/surfsense_backend/app/agents/chat/runtime/llm_config.py index aad432edb..03d7f548e 100644 --- a/surfsense_backend/app/agents/chat/runtime/llm_config.py +++ b/surfsense_backend/app/agents/chat/runtime/llm_config.py @@ -351,7 +351,7 @@ async def load_agent_llm_config_for_search_space( session: AsyncSession, search_space_id: int, ) -> "AgentConfig | None": - """Load the agent LLM config for a search space via its agent_llm_id. + """Load the chat model config for a search space via its agent_llm_id. Positive id -> DB; negative -> YAML; None -> first global config (-1). """ @@ -372,7 +372,7 @@ async def load_agent_llm_config_for_search_space( ) return await load_agent_config(session, config_id, search_space_id) except Exception as e: - print(f"Error loading agent LLM config for search space {search_space_id}: {e}") + print(f"Error loading chat model config for search space {search_space_id}: {e}") return None diff --git a/surfsense_backend/app/agents/podcaster/nodes.py b/surfsense_backend/app/agents/podcaster/nodes.py index d1f140a44..0d54cbe45 100644 --- a/surfsense_backend/app/agents/podcaster/nodes.py +++ b/surfsense_backend/app/agents/podcaster/nodes.py @@ -31,7 +31,7 @@ async def create_podcast_transcript( llm = await get_agent_llm(state.db_session, search_space_id) if not llm: - error_message = f"No agent LLM configured for search space {search_space_id}" + error_message = f"No chat model configured for search space {search_space_id}" print(error_message) raise RuntimeError(error_message) From 780e24213240310d52071c89c528ad1643d86071 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 00:11:53 +0530 Subject: [PATCH 013/212] feat(model-connections): implement manual model addition and enhance model discovery --- .../app/routes/model_connections_routes.py | 38 +++ surfsense_backend/app/schemas/__init__.py | 1 + .../app/schemas/model_connections.py | 11 + .../app/services/model_connection_service.py | 28 +- .../model-connections-mutation.atoms.ts | 28 +- .../settings/model-connections-settings.tsx | 312 ++++++++++++------ .../types/model-connections.types.ts | 6 + .../lib/apis/model-connections-api.service.ts | 12 + 8 files changed, 335 insertions(+), 101 deletions(-) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 69910183d..6d19a5ed1 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -10,6 +10,7 @@ from app.db import ( Connection, ConnectionScope, Model, + ModelSource, Permission, SearchSpace, User, @@ -19,6 +20,7 @@ from app.schemas import ( ConnectionCreate, ConnectionRead, ConnectionUpdate, + ModelCreate, ModelRead, ModelRolesRead, ModelRolesUpdate, @@ -26,6 +28,7 @@ from app.schemas import ( VerifyConnectionResponse, ) from app.services.model_connection_service import ( + derive_capabilities, discover_models, persist_verification, test_model, @@ -254,6 +257,41 @@ async def discover_connection_models( return [_model_read(model) for model in conn.models] +@router.post("/model-connections/{connection_id}/models", response_model=ModelRead) +async def add_manual_model( + connection_id: int, + data: ModelCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + conn = await _load_connection(session, connection_id) + await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value) + + model_id = data.model_id.strip() + if not model_id: + raise HTTPException(status_code=400, detail="model_id is required") + if any(existing.model_id == model_id for existing in conn.models): + raise HTTPException(status_code=400, detail="Model already exists on this connection") + + capabilities = derive_capabilities(conn, model_id) + model = Model( + connection_id=conn.id, + model_id=model_id, + display_name=data.display_name or None, + source=ModelSource.MANUAL, + capabilities=capabilities, + capabilities_declared=capabilities, + capabilities_verified={}, + capabilities_override={}, + enabled=True, + catalog={}, + ) + session.add(model) + await session.commit() + await session.refresh(model) + return _model_read(model) + + @router.put("/models/{model_id}", response_model=ModelRead) async def update_model( model_id: int, diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index c14671c99..2a06eca5c 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -48,6 +48,7 @@ from .model_connections import ( ConnectionCreate, ConnectionRead, ConnectionUpdate, + ModelCreate, ModelRead, ModelRolesRead, ModelRolesUpdate, diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py index 731064375..ea1ec4e88 100644 --- a/surfsense_backend/app/schemas/model_connections.py +++ b/surfsense_backend/app/schemas/model_connections.py @@ -65,6 +65,17 @@ class ConnectionUpdate(BaseModel): enabled: bool | None = None +class ModelCreate(BaseModel): + """Manually register a model id on a connection. + + For providers without a usable ``/models`` endpoint (Perplexity, MiniMax, + Azure deployments, etc.) or to pin a single model from a noisy provider. + """ + + model_id: str = Field(..., max_length=255) + display_name: str | None = Field(None, max_length=255) + + class ModelUpdate(BaseModel): display_name: str | None = Field(None, max_length=255) enabled: bool | None = None diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index 81090acaf..c8d2e8a5a 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -122,6 +122,17 @@ def _litellm_capabilities(model_string: str, model_id: str) -> dict[str, bool]: return capabilities +def _allowlist(conn: Connection) -> set[str]: + """Per-connection model-id allowlist stored in ``extra.model_ids``. + + Empty/absent means "no restriction" (discover everything), mirroring + OpenWebUI's behaviour. A non-empty list restricts discovery to those ids — + essential for providers like OpenRouter that expose hundreds of models. + """ + raw = (conn.extra or {}).get("model_ids") or [] + return {str(item).strip() for item in raw if str(item).strip()} + + def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, bool]: metadata = metadata or {} if conn.protocol == ConnectionProtocol.OLLAMA: @@ -140,13 +151,15 @@ def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = async def discover_models(conn: Connection) -> list[dict[str, Any]]: + allowlist = _allowlist(conn) + if conn.protocol == ConnectionProtocol.OLLAMA: url = f"{conn.base_url.rstrip('/')}/api/tags" async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: response = await client.get(url, headers=_auth_headers(conn)) response.raise_for_status() models = response.json().get("models", []) - return [ + results = [ { "model_id": item.get("model") or item.get("name"), "display_name": item.get("name") or item.get("model"), @@ -157,14 +170,13 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]: for item in models if item.get("model") or item.get("name") ] - - if conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: + elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: url = f"{ensure_v1(conn.base_url)}/models" async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: response = await client.get(url, headers=_auth_headers(conn)) response.raise_for_status() models = response.json().get("data", []) - return [ + results = [ { "model_id": item.get("id"), "display_name": item.get("name") or item.get("id"), @@ -175,9 +187,13 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]: for item in models if item.get("id") ] + else: + # Native providers rely on curated/global catalog entries or manual rows. + return [] - # Native providers rely on curated/global catalog entries or manual rows. - return [] + if allowlist: + results = [item for item in results if item["model_id"] in allowlist] + return results async def test_model(conn: Connection, model: Model) -> VerifyResult: diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts index 7d58a402c..612216bf2 100644 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -3,6 +3,7 @@ import { toast } from "sonner"; import type { ConnectionCreateRequest, ConnectionUpdateRequest, + ModelCreateRequest, ModelRoles, ModelUpdateRequest, } from "@/contracts/types/model-connections.types"; @@ -67,8 +68,17 @@ export const verifyModelConnectionMutationAtom = atomWithMutation((get) => { mutationKey: ["model-connections", "verify"], mutationFn: (id: number) => modelConnectionsApiService.verifyConnection(id), onSuccess: (result) => { - if (result.ok) toast.success("Connection verified"); - else toast.error(result.message || "Connection failed"); + if (result.ok) { + toast.success("Connection verified"); + } else { + // Non-fatal: many providers lack a /models endpoint yet still serve + // chat. Guide the user to add model IDs manually instead of alarming. + toast.warning( + result.message + ? `${result.message} Chat may still work — add model IDs manually.` + : "Couldn't list models. Chat may still work — add model IDs manually." + ); + } invalidateModelConnections(searchSpaceId); }, onError: (error: Error) => toast.error(error.message || "Failed to verify connection"), @@ -88,6 +98,20 @@ export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => { }; }); +export const addManualModelMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["models", "add-manual"], + mutationFn: ({ connectionId, data }: { connectionId: number; data: ModelCreateRequest }) => + modelConnectionsApiService.addManualModel(connectionId, data), + onSuccess: () => { + toast.success("Model added"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to add model"), + }; +}); + export const updateModelMutationAtom = atomWithMutation((get) => { const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); return { diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 5fa4cccf7..e89fc3278 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -1,12 +1,14 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { CheckCircle2, PlugZap, RefreshCcw, XCircle } from "lucide-react"; +import { CheckCircle2, PlugZap, Plus, RefreshCcw, XCircle } from "lucide-react"; import { useMemo, useState } from "react"; import { + addManualModelMutationAtom, createModelConnectionMutationAtom, discoverConnectionModelsMutationAtom, testModelMutationAtom, + updateModelConnectionMutationAtom, updateModelMutationAtom, updateModelRolesMutationAtom, verifyModelConnectionMutationAtom, @@ -46,9 +48,16 @@ type Preset = { }; const PRESETS: Preset[] = [ + { id: "custom", label: "OpenAI-compatible (any URL)", protocol: "OPENAI_COMPATIBLE" }, { id: "openai", label: "OpenAI", protocol: "NATIVE", nativeProvider: "OPENAI" }, { id: "anthropic", label: "Anthropic", protocol: "NATIVE", nativeProvider: "ANTHROPIC" }, - { id: "openrouter", label: "OpenRouter", protocol: "NATIVE", nativeProvider: "OPENROUTER" }, + { + id: "openrouter", + label: "OpenRouter", + protocol: "NATIVE", + nativeProvider: "OPENROUTER", + baseUrl: "https://openrouter.ai/api/v1", + }, { id: "ollama", label: "Ollama", @@ -86,6 +95,22 @@ const PRESETS: Preset[] = [ }, ]; +// Free-text URL hints (datalist), mirroring OpenWebUI. These never restrict +// what the user can type — any OpenAI-compatible endpoint works. +const URL_SUGGESTIONS = [ + "https://api.openai.com/v1", + "https://api.anthropic.com/v1", + "https://openrouter.ai/api/v1", + "https://generativelanguage.googleapis.com/v1beta/openai", + "https://api.groq.com/openai/v1", + "https://api.mistral.ai/v1", + "https://api.deepseek.com/v1", + "https://api.x.ai/v1", + "http://host.docker.internal:11434", + "http://host.docker.internal:1234/v1", + "http://host.docker.internal:8000/v1", +]; + function modelLabel(model: ModelRead) { return model.display_name || model.model_id; } @@ -123,22 +148,183 @@ function flattenModels(connections: ConnectionRead[]) { ); } +function ConnectionCard({ connection }: { connection: ConnectionRead }) { + const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom); + const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); + const updateConnection = useAtomValue(updateModelConnectionMutationAtom); + const addManualModel = useAtomValue(addManualModelMutationAtom); + const updateModel = useAtomValue(updateModelMutationAtom); + const testModel = useAtomValue(testModelMutationAtom); + + const allowlist = Array.isArray(connection.extra?.model_ids) + ? (connection.extra.model_ids as string[]) + : []; + const [allowlistText, setAllowlistText] = useState(allowlist.join(", ")); + const [manualModelId, setManualModelId] = useState(""); + + const providerLabel = connection.native_provider || connection.protocol; + const isLocal = connection.protocol === "OLLAMA" || !connection.base_url?.startsWith("https"); + + function saveAllowlist() { + const ids = allowlistText + .split(",") + .map((value) => value.trim()) + .filter(Boolean); + updateConnection.mutate({ + id: connection.id, + data: { extra: { ...(connection.extra ?? {}), model_ids: ids } }, + }); + } + + function addModel() { + const modelId = manualModelId.trim(); + if (!modelId) return; + addManualModel.mutate( + { connectionId: connection.id, data: { model_id: modelId } }, + { onSuccess: () => setManualModelId("") } + ); + } + + return ( +
+
+
+
+ {getProviderIcon(providerLabel, { className: "size-4" })} + {providerLabel} +
+
+ {connection.base_url || "Provider default endpoint"} +
+
+
+ + + +
+
+ + {connection.last_status && connection.last_status !== "OK" ? ( +

+ {connection.last_error || "Could not list models."} Chat may still work — add model + IDs manually below. +

+ ) : null} + + {!isLocal ? ( +
+ +
+ setAllowlistText(event.target.value)} + placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro" + /> + +
+

+ Leave empty to discover all models. Recommended for providers with large catalogs + (e.g. OpenRouter). +

+
+ ) : null} + +
+ setManualModelId(event.target.value)} + onKeyDown={(event) => { + if (event.key === "Enter") { + event.preventDefault(); + addModel(); + } + }} + placeholder="Add a model ID manually (for providers without /models)" + /> + +
+ +
+ {connection.models.map((model) => ( +
+
+
+ {getProviderIcon(providerLabel, { className: "size-4" })} + {modelLabel(model)} + {model.source === "MANUAL" ? ( + + manual + + ) : null} +
+
+ {["chat", "vision", "image_gen"] + .filter((key) => Boolean(model.capabilities?.[key])) + .join(", ") || "No verified capabilities"} +
+
+
+ + +
+
+ ))} +
+
+ ); +} + export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: number }) { const [{ data: globalConnections = [] }] = useAtom(globalModelConnectionsAtom); const [{ data: connections = [] }] = useAtom(modelConnectionsAtom); const [{ data: roles }] = useAtom(modelRolesAtom); const createConnection = useAtomValue(createModelConnectionMutationAtom); - const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom); - const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); - const updateModel = useAtomValue(updateModelMutationAtom); - const testModel = useAtomValue(testModelMutationAtom); const updateRoles = useAtomValue(updateModelRolesMutationAtom); const visiblePresets = useMemo( () => PRESETS.filter((preset) => !(isCloud() && preset.local)), [] ); - const [presetId, setPresetId] = useState(visiblePresets[0]?.id ?? "openai"); + const [presetId, setPresetId] = useState(visiblePresets[0]?.id ?? "custom"); const preset = visiblePresets.find((item) => item.id === presetId) ?? visiblePresets[0]; const [baseUrl, setBaseUrl] = useState(preset?.baseUrl ?? ""); const [apiKey, setApiKey] = useState(""); @@ -157,16 +343,19 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num function handleCreate() { if (!preset) return; - createConnection.mutate({ - protocol: preset.protocol, - native_provider: preset.nativeProvider, - base_url: baseUrl || null, - api_key: apiKey || null, - scope: "SEARCH_SPACE", - search_space_id: searchSpaceId, - extra: {}, - enabled: true, - }); + createConnection.mutate( + { + protocol: preset.protocol, + native_provider: preset.nativeProvider, + base_url: baseUrl || null, + api_key: apiKey || null, + scope: "SEARCH_SPACE", + search_space_id: searchSpaceId, + extra: {}, + enabled: true, + }, + { onSuccess: () => setApiKey("") } + ); } function renderModelOption(model: ModelRead & { connectionName: string; provider: string }) { @@ -192,7 +381,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
- +
- - setBaseUrl(event.target.value)} - placeholder="https://api.example.com/v1" - list="model-conn-url-suggestions" - /> - - {URL_SUGGESTIONS.map((url) => ( - + + {isNative && !showCustomEndpoint ? ( +
+
+ Uses provider default +
+ +
+ ) : ( + <> + setBaseUrl(event.target.value)} + placeholder="https://api.example.com/v1" + list="model-conn-url-suggestions" + /> + + {URL_SUGGESTIONS.map((url) => ( + + + )}
- + setApiKey(event.target.value)} - placeholder="Optional for local models" + placeholder={preset?.local ? "Optional for local models" : "API key"} type="password" />
-
@@ -434,10 +457,15 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num Local URLs are tested from the backend container. Use host.docker.internal instead of localhost.

+ ) : isNative ? ( +

+ Just paste an API key — {preset?.label} routes through its native endpoint + automatically. After adding, hit Discover (or add model IDs manually). +

) : preset?.protocol === "OPENAI_COMPATIBLE" ? (

- Works with any OpenAI-compatible endpoint (OpenRouter, Together, Groq, vLLM, LM - Studio…). After adding, hit Discover to list models. + Enter any OpenAI-compatible endpoint (OpenRouter, Together, Groq, vLLM, LM Studio…). + After adding, hit Discover to list models.

) : null} From 50c816c81c0b85e3c6f28fe4719e0be85d183cac Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 10:22:39 +0530 Subject: [PATCH 015/212] refactor(model-connections): streamline connection reading and model handling in routes --- .../app/routes/model_connections_routes.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 6d19a5ed1..5872671b1 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -69,7 +69,7 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None last_verified_at=conn.last_verified_at, last_status=conn.last_status, last_error=conn.last_error, - models=[_model_read(model) for model in (models or conn.models or [])], + models=[_model_read(model) for model in (models or [])], created_at=conn.created_at, ) @@ -144,7 +144,10 @@ async def list_connections( else: stmt = stmt.where(Connection.user_id == user.id) result = await session.execute(stmt.order_by(Connection.id)) - return [_connection_read(conn) for conn in result.scalars().all()] + return [ + _connection_read(conn, list(conn.models)) + for conn in result.scalars().all() + ] @router.post("/model-connections", response_model=ConnectionRead) @@ -173,7 +176,7 @@ async def create_connection( session.add(conn) await session.commit() await session.refresh(conn) - return _connection_read(conn) + return _connection_read(conn, []) @router.put("/model-connections/{connection_id}", response_model=ConnectionRead) @@ -188,8 +191,8 @@ async def update_connection( for key, value in data.model_dump(exclude_unset=True).items(): setattr(conn, key, value) await session.commit() - await session.refresh(conn) - return _connection_read(conn) + conn = await _load_connection(session, connection_id) + return _connection_read(conn, list(conn.models)) @router.delete("/model-connections/{connection_id}") @@ -253,7 +256,7 @@ async def discover_connection_models( } db_model.catalog = item.get("metadata") or db_model.catalog await session.commit() - await session.refresh(conn) + conn = await _load_connection(session, connection_id) return [_model_read(model) for model in conn.models] From 3f016421992919a63211df2adbdc13df39278fbc Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 17:29:55 +0530 Subject: [PATCH 016/212] feat(model-connections): enhance model discovery with OpenAI and LiteLLM support --- .../app/services/model_connection_service.py | 93 +++++++++++++++---- .../model-connections-mutation.atoms.ts | 8 +- 2 files changed, 80 insertions(+), 21 deletions(-) diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index c8d2e8a5a..42a4792a4 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import contextlib import logging from dataclasses import dataclass @@ -12,7 +13,8 @@ import httpx import litellm from app.db import Connection, ConnectionProtocol, Model, ModelSource -from app.services.model_resolver import ensure_v1, to_litellm +from app.services.model_resolver import NATIVE_PROVIDER_PREFIX, ensure_v1, to_litellm +from app.services.provider_api_base import resolve_api_base logger = logging.getLogger(__name__) @@ -133,6 +135,63 @@ def _allowlist(conn: Connection) -> set[str]: return {str(item).strip() for item in raw if str(item).strip()} +async def _discover_openai_shaped_models(conn: Connection, base_url: str | None) -> list[dict[str, Any]]: + if not base_url: + return [] + + url = f"{ensure_v1(base_url)}/models" + async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: + response = await client.get(url, headers=_auth_headers(conn)) + response.raise_for_status() + return [ + { + "model_id": item.get("id"), + "display_name": item.get("name") or item.get("id"), + "source": ModelSource.DISCOVERED, + "capabilities": derive_capabilities(conn, item.get("id"), item), + "metadata": item, + } + for item in response.json().get("data", []) + if item.get("id") + ] + + +def _litellm_valid_model_ids(provider: str, api_key: str | None) -> list[str]: + if not api_key: + return [] + + try: + models = litellm.get_valid_models( + check_provider_endpoint=True, + custom_llm_provider=provider, + api_key=api_key, + ) + except Exception as exc: + logger.warning("LiteLLM model discovery failed for provider %s: %s", provider, exc) + return [] + + provider_prefix = f"{provider}/" + return [ + model.removeprefix(provider_prefix) + for model in models + if isinstance(model, str) and model.strip() + ] + + +async def _discover_litellm_native_models(conn: Connection, provider: str) -> list[dict[str, Any]]: + model_ids = await asyncio.to_thread(_litellm_valid_model_ids, provider, conn.api_key) + return [ + { + "model_id": model_id, + "display_name": model_id, + "source": ModelSource.DISCOVERED, + "capabilities": derive_capabilities(conn, model_id), + "metadata": {}, + } + for model_id in model_ids + ] + + def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, bool]: metadata = metadata or {} if conn.protocol == ConnectionProtocol.OLLAMA: @@ -171,25 +230,21 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]: if item.get("model") or item.get("name") ] elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: - url = f"{ensure_v1(conn.base_url)}/models" - async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: - response = await client.get(url, headers=_auth_headers(conn)) - response.raise_for_status() - models = response.json().get("data", []) - results = [ - { - "model_id": item.get("id"), - "display_name": item.get("name") or item.get("id"), - "source": ModelSource.DISCOVERED, - "capabilities": derive_capabilities(conn, item.get("id"), item), - "metadata": item, - } - for item in models - if item.get("id") - ] + results = await _discover_openai_shaped_models(conn, conn.base_url) else: - # Native providers rely on curated/global catalog entries or manual rows. - return [] + provider_key = (conn.native_provider or "").upper() + provider = NATIVE_PROVIDER_PREFIX.get(provider_key, provider_key.lower()) + api_base = resolve_api_base( + provider=provider_key, + provider_prefix=provider, + config_api_base=conn.base_url, + ) + if api_base: + results = await _discover_openai_shaped_models(conn, api_base) + elif provider: + results = await _discover_litellm_native_models(conn, provider) + else: + results = [] if allowlist: results = [item for item in results if item["model_id"] in allowlist] diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts index 612216bf2..76289e60d 100644 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -90,8 +90,12 @@ export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => { return { mutationKey: ["model-connections", "discover"], mutationFn: (id: number) => modelConnectionsApiService.discoverModels(id), - onSuccess: () => { - toast.success("Models discovered"); + onSuccess: (models) => { + toast.success( + models.length + ? `${models.length} models discovered` + : "No models found for this connection" + ); invalidateModelConnections(searchSpaceId); }, onError: (error: Error) => toast.error(error.message || "Failed to discover models"), From c6a25cc1fe1e128852146a09d2877d3496a2aaee Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 18:20:53 +0530 Subject: [PATCH 017/212] refactor(model-connections): streamline global model config persistence --- .../versions/156_add_model_connections.py | 266 +++++++++++------- surfsense_backend/app/config/__init__.py | 9 +- .../app/config/global_llm_config.example.yaml | 60 ++-- surfsense_backend/app/db.py | 4 +- .../app/routes/model_connections_routes.py | 19 +- .../app/routes/new_llm_config_routes.py | 6 +- .../app/routes/search_spaces_routes.py | 6 +- .../app/schemas/model_connections.py | 6 +- .../app/services/model_connection_service.py | 85 +++--- .../tests/e2e/fixtures/global_llm_config.yaml | 8 +- .../routes/test_global_configs_is_premium.py | 12 +- ...t_global_new_llm_configs_supports_image.py | 8 +- .../unit/services/test_model_connections.py | 12 +- 13 files changed, 277 insertions(+), 224 deletions(-) diff --git a/surfsense_backend/alembic/versions/156_add_model_connections.py b/surfsense_backend/alembic/versions/156_add_model_connections.py index 0a11d7f9d..185debca4 100644 --- a/surfsense_backend/alembic/versions/156_add_model_connections.py +++ b/surfsense_backend/alembic/versions/156_add_model_connections.py @@ -20,7 +20,7 @@ depends_on: str | Sequence[str] | None = None connection_protocol = postgresql.ENUM( "OLLAMA", "OPENAI_COMPATIBLE", - "NATIVE", + "ANTHROPIC", name="connectionprotocol", create_type=False, ) @@ -39,122 +39,172 @@ model_source = postgresql.ENUM( ) +def _table_exists(table_name: str) -> bool: + return table_name in sa.inspect(op.get_bind()).get_table_names() + + +def _column_exists(table_name: str, column_name: str) -> bool: + if not _table_exists(table_name): + return False + return column_name in { + column["name"] for column in sa.inspect(op.get_bind()).get_columns(table_name) + } + + +def _index_exists(table_name: str, index_name: str) -> bool: + if not _table_exists(table_name): + return False + return index_name in { + index["name"] for index in sa.inspect(op.get_bind()).get_indexes(table_name) + } + + +def _create_index_if_missing( + index_name: str, + table_name: str, + columns: list[str], +) -> None: + if not _index_exists(table_name, index_name): + op.create_index(index_name, table_name, columns, unique=False) + + +def _add_searchspace_column_if_missing(column_name: str) -> None: + if not _column_exists("searchspaces", column_name): + op.add_column("searchspaces", sa.Column(column_name, sa.Integer(), nullable=True)) + + def upgrade() -> None: bind = op.get_bind() connection_protocol.create(bind, checkfirst=True) + op.execute("ALTER TYPE connectionprotocol ADD VALUE IF NOT EXISTS 'ANTHROPIC'") connection_scope.create(bind, checkfirst=True) model_source.create(bind, checkfirst=True) - op.create_table( + if _table_exists("connections"): + if _column_exists("connections", "native_provider") and not _column_exists( + "connections", "litellm_provider" + ): + op.alter_column( + "connections", + "native_provider", + new_column_name="litellm_provider", + existing_type=sa.String(length=100), + existing_nullable=True, + ) + elif not _column_exists("connections", "litellm_provider"): + op.add_column( + "connections", + sa.Column("litellm_provider", sa.String(length=100), nullable=True), + ) + else: + op.create_table( + "connections", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("protocol", connection_protocol, nullable=False), + sa.Column("litellm_provider", sa.String(length=100), nullable=True), + sa.Column("base_url", sa.String(length=500), nullable=True), + sa.Column("api_key", sa.String(), nullable=True), + sa.Column( + "extra", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column("scope", connection_scope, nullable=False), + sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False), + sa.Column("search_space_id", sa.Integer(), nullable=True), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("last_verified_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column("last_status", sa.String(length=50), nullable=True), + sa.Column("last_error", sa.Text(), nullable=True), + sa.CheckConstraint( + "(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR " + "(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR " + "(scope = 'USER' AND user_id IS NOT NULL)", + name="ck_connections_scope_owner", + ), + sa.ForeignKeyConstraint( + ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + if _index_exists("connections", "ix_connections_native_provider") and not _index_exists( + "connections", "ix_connections_litellm_provider" + ): + op.execute( + "ALTER INDEX ix_connections_native_provider " + "RENAME TO ix_connections_litellm_provider" + ) + _create_index_if_missing("ix_connections_protocol", "connections", ["protocol"]) + _create_index_if_missing( + "ix_connections_litellm_provider", "connections", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("protocol", connection_protocol, nullable=False), - sa.Column("native_provider", sa.String(length=100), nullable=True), - sa.Column("base_url", sa.String(length=500), nullable=True), - sa.Column("api_key", sa.String(), nullable=True), - sa.Column( - "extra", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column("scope", connection_scope, nullable=False), - sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False), - sa.Column("search_space_id", sa.Integer(), nullable=True), - sa.Column("user_id", sa.UUID(), nullable=True), - sa.Column("last_verified_at", sa.TIMESTAMP(timezone=True), nullable=True), - sa.Column("last_status", sa.String(length=50), nullable=True), - sa.Column("last_error", sa.Text(), nullable=True), - sa.CheckConstraint( - "(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR " - "(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR " - "(scope = 'USER' AND user_id IS NOT NULL)", - name="ck_connections_scope_owner", - ), - sa.ForeignKeyConstraint( - ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" - ), - sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("id"), + ["litellm_provider"], ) - op.create_index(op.f("ix_connections_protocol"), "connections", ["protocol"], unique=False) - op.create_index( - op.f("ix_connections_native_provider"), - "connections", - ["native_provider"], - unique=False, - ) - op.create_index(op.f("ix_connections_scope"), "connections", ["scope"], unique=False) + _create_index_if_missing("ix_connections_scope", "connections", ["scope"]) - op.create_table( - "models", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("connection_id", sa.Integer(), nullable=False), - sa.Column("model_id", sa.String(length=255), nullable=False), - sa.Column("display_name", sa.String(length=255), nullable=True), - sa.Column( - "source", - model_source, - server_default="DISCOVERED", - nullable=False, - ), - sa.Column( - "capabilities", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column( - "capabilities_declared", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column( - "capabilities_verified", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column( - "capabilities_override", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column("embedding_dimension", sa.Integer(), nullable=True), - sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False), - sa.Column("billing_tier", sa.String(length=50), nullable=True), - sa.Column( - "catalog", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.ForeignKeyConstraint(["connection_id"], ["connections.id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint( - "connection_id", "model_id", name="uq_models_connection_model_id" - ), - ) - op.create_index(op.f("ix_models_connection_id"), "models", ["connection_id"], unique=False) - op.create_index("ix_models_model_id", "models", ["model_id"], unique=False) - op.create_index(op.f("ix_models_billing_tier"), "models", ["billing_tier"], unique=False) + if not _table_exists("models"): + op.create_table( + "models", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("connection_id", sa.Integer(), nullable=False), + sa.Column("model_id", sa.String(length=255), nullable=False), + sa.Column("display_name", sa.String(length=255), nullable=True), + sa.Column( + "source", + model_source, + server_default="DISCOVERED", + nullable=False, + ), + sa.Column( + "capabilities", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column( + "capabilities_declared", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column( + "capabilities_verified", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column( + "capabilities_override", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column("embedding_dimension", sa.Integer(), nullable=True), + sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False), + sa.Column("billing_tier", sa.String(length=50), nullable=True), + sa.Column( + "catalog", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.ForeignKeyConstraint(["connection_id"], ["connections.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint( + "connection_id", "model_id", name="uq_models_connection_model_id" + ), + ) + _create_index_if_missing("ix_models_connection_id", "models", ["connection_id"]) + _create_index_if_missing("ix_models_model_id", "models", ["model_id"]) + _create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"]) - op.add_column( - "searchspaces", - sa.Column("chat_model_id", sa.Integer(), nullable=True), - ) - op.add_column( - "searchspaces", - sa.Column("image_gen_model_id", sa.Integer(), nullable=True), - ) - op.add_column( - "searchspaces", - sa.Column("vision_model_id", sa.Integer(), nullable=True), - ) + _add_searchspace_column_if_missing("chat_model_id") + _add_searchspace_column_if_missing("image_gen_model_id") + _add_searchspace_column_if_missing("vision_model_id") def downgrade() -> None: @@ -168,7 +218,7 @@ def downgrade() -> None: op.drop_table("models") op.drop_index(op.f("ix_connections_scope"), table_name="connections") - op.drop_index(op.f("ix_connections_native_provider"), table_name="connections") + op.drop_index(op.f("ix_connections_litellm_provider"), table_name="connections") op.drop_index(op.f("ix_connections_protocol"), table_name="connections") op.drop_table("connections") diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index b8addb45d..fd8b29116 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -78,8 +78,7 @@ def load_global_llm_configs(): # stamps) never leak into the cached YAML structure. configs = copy.deepcopy(data.get("global_llm_configs", [])) - # Lazy import keeps the `app.config` -> `app.services` edge one-way - # and matches the `provider_api_base` pattern used elsewhere. + # Lazy import keeps the `app.config` -> `app.services` edge one-way. from app.services.provider_capabilities import derive_supports_image_input seen_slugs: dict[str, int] = {} @@ -104,7 +103,7 @@ def load_global_llm_configs(): else None ) cfg["supports_image_input"] = derive_supports_image_input( - provider=cfg.get("provider"), + litellm_provider=cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), @@ -123,7 +122,7 @@ def load_global_llm_configs(): # Stamp Auto (Fastest) ranking metadata. YAML configs are always # Tier A — operator-curated, locked first when premium-eligible. # The OpenRouter refresh tick later re-stamps health for any cfg - # whose provider == "OPENROUTER" via _enrich_health. + # whose litellm_provider == "openrouter" via _enrich_health. try: from app.services.quality_score import static_score_yaml @@ -133,7 +132,7 @@ def load_global_llm_configs(): cfg["quality_score_static"] = static_q cfg["quality_score"] = static_q cfg["quality_score_health"] = None - # YAML cfgs whose provider is OPENROUTER are also subject + # YAML cfgs whose litellm_provider is openrouter are also subject # to health gating against their own /endpoints data — a # hand-picked dead OR model is still dead. _enrich_health # re-stamps health_gated for them on the next refresh tick. diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index b0eee6458..06676511f 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -18,7 +18,7 @@ # - Configure router_settings below to customize the load balancing behavior # # Static config shape: -# - Connection fields: provider, api_key, api_base, api_version +# - Connection fields: litellm_provider, api_key, api_base, api_version # - Model fields: model_name, billing_tier, rpm/tpm, litellm_params # - Prompt defaults: system_instructions, citations_enabled # IDs share one GLOBAL model namespace across chat, vision, and image generation. @@ -75,10 +75,10 @@ global_llm_configs: seo_enabled: true seo_slug: "gpt-4-turbo" quota_reserve_tokens: 4000 - provider: "OPENAI" + litellm_provider: "openai" model_name: "gpt-4-turbo-preview" api_key: "sk-your-openai-api-key-here" - api_base: "" + api_base: "https://api.openai.com/v1" # Rate limits for load balancing (requests/tokens per minute) rpm: 500 # Requests per minute tpm: 100000 # Tokens per minute @@ -99,10 +99,10 @@ global_llm_configs: seo_enabled: true seo_slug: "claude-3-opus" quota_reserve_tokens: 4000 - provider: "ANTHROPIC" + litellm_provider: "anthropic" model_name: "claude-3-opus-20240229" api_key: "sk-ant-your-anthropic-api-key-here" - api_base: "" + api_base: "https://api.anthropic.com/v1" rpm: 1000 tpm: 100000 litellm_params: @@ -121,10 +121,10 @@ global_llm_configs: seo_enabled: true seo_slug: "gpt-3.5-turbo-fast" quota_reserve_tokens: 2000 - provider: "OPENAI" + litellm_provider: "openai" model_name: "gpt-3.5-turbo" api_key: "sk-your-openai-api-key-here" - api_base: "" + api_base: "https://api.openai.com/v1" rpm: 3500 # GPT-3.5 has higher rate limits tpm: 200000 litellm_params: @@ -143,7 +143,7 @@ global_llm_configs: seo_enabled: true seo_slug: "deepseek-chat-chinese" quota_reserve_tokens: 4000 - provider: "DEEPSEEK" + litellm_provider: "openai" model_name: "deepseek-chat" api_key: "your-deepseek-api-key-here" api_base: "https://api.deepseek.com/v1" @@ -175,7 +175,7 @@ global_llm_configs: seo_enabled: true seo_slug: "azure-gpt-4o" quota_reserve_tokens: 4000 - provider: "AZURE" + litellm_provider: "azure" # model_name format for Azure: azure/ model_name: "azure/gpt-4o-deployment" api_key: "your-azure-api-key-here" @@ -203,7 +203,7 @@ global_llm_configs: seo_enabled: true seo_slug: "azure-gpt-4-turbo" quota_reserve_tokens: 4000 - provider: "AZURE" + litellm_provider: "azure" model_name: "azure/gpt-4-turbo-deployment" api_key: "your-azure-api-key-here" api_base: "https://your-resource.openai.azure.com" @@ -227,10 +227,10 @@ global_llm_configs: seo_enabled: true seo_slug: "groq-llama-3" quota_reserve_tokens: 8000 - provider: "GROQ" + litellm_provider: "groq" model_name: "llama3-70b-8192" api_key: "your-groq-api-key-here" - api_base: "" + api_base: "https://api.groq.com/openai/v1" rpm: 30 # Groq has lower rate limits on free tier tpm: 14400 litellm_params: @@ -249,7 +249,7 @@ global_llm_configs: seo_enabled: true seo_slug: "minimax-m3" quota_reserve_tokens: 4000 - provider: "MINIMAX" + litellm_provider: "openai" model_name: "MiniMax-M3" api_key: "your-minimax-api-key-here" api_base: "https://api.minimax.io/v1" @@ -288,10 +288,10 @@ global_llm_configs: anonymous_enabled: false seo_enabled: false quota_reserve_tokens: 1000 - provider: "OPENAI" + litellm_provider: "openai" model_name: "gpt-4o-mini" api_key: "sk-your-openai-api-key-here" - api_base: "" + api_base: "https://api.openai.com/v1" rpm: 3500 tpm: 200000 litellm_params: @@ -391,10 +391,10 @@ global_image_generation_configs: - id: -2001 name: "Global DALL-E 3" description: "OpenAI's DALL-E 3 for high-quality image generation" - provider: "OPENAI" + litellm_provider: "openai" model_name: "dall-e-3" api_key: "sk-your-openai-api-key-here" - api_base: "" + api_base: "https://api.openai.com/v1" rpm: 50 # Requests per minute (image gen is rate-limited by RPM, not tokens) litellm_params: {} @@ -402,10 +402,10 @@ global_image_generation_configs: - id: -2002 name: "Global GPT Image 1" description: "OpenAI's GPT Image 1 model" - provider: "OPENAI" + litellm_provider: "openai" model_name: "gpt-image-1" api_key: "sk-your-openai-api-key-here" - api_base: "" + api_base: "https://api.openai.com/v1" rpm: 50 litellm_params: {} @@ -413,7 +413,7 @@ global_image_generation_configs: - id: -2003 name: "Global Azure DALL-E 3" description: "Azure-hosted DALL-E 3 deployment" - provider: "AZURE_OPENAI" + litellm_provider: "azure" model_name: "azure/dall-e-3-deployment" api_key: "your-azure-api-key-here" api_base: "https://your-resource.openai.azure.com" @@ -426,10 +426,10 @@ global_image_generation_configs: # - id: -2004 # name: "Global Gemini Image Gen" # description: "Google Gemini image generation via OpenRouter" - # provider: "OPENROUTER" + # litellm_provider: "openrouter" # model_name: "google/gemini-2.5-flash-image" # api_key: "your-openrouter-api-key-here" - # api_base: "" + # api_base: "https://openrouter.ai/api/v1" # rpm: 30 # litellm_params: {} @@ -455,10 +455,10 @@ global_vision_llm_configs: - id: -1001 name: "Global GPT-4o Vision" description: "OpenAI's GPT-4o with strong vision capabilities" - provider: "OPENAI" + litellm_provider: "openai" model_name: "gpt-4o" api_key: "sk-your-openai-api-key-here" - api_base: "" + api_base: "https://api.openai.com/v1" rpm: 500 tpm: 100000 litellm_params: @@ -469,10 +469,10 @@ global_vision_llm_configs: - id: -1002 name: "Global Gemini 2.0 Flash" description: "Google's fast vision model with large context" - provider: "GOOGLE" + litellm_provider: "gemini" model_name: "gemini-2.0-flash" api_key: "your-google-ai-api-key-here" - api_base: "" + api_base: "https://generativelanguage.googleapis.com/v1beta" rpm: 1000 tpm: 200000 litellm_params: @@ -483,10 +483,10 @@ global_vision_llm_configs: - id: -1003 name: "Global Claude 3.5 Sonnet Vision" description: "Anthropic's Claude 3.5 Sonnet with vision support" - provider: "ANTHROPIC" + litellm_provider: "anthropic" model_name: "claude-3-5-sonnet-20241022" api_key: "sk-ant-your-anthropic-api-key-here" - api_base: "" + api_base: "https://api.anthropic.com/v1" rpm: 1000 tpm: 100000 litellm_params: @@ -497,7 +497,7 @@ global_vision_llm_configs: # - id: -1004 # name: "Global Azure GPT-4o Vision" # description: "Azure-hosted GPT-4o for vision analysis" - # provider: "AZURE_OPENAI" + # litellm_provider: "azure" # model_name: "azure/gpt-4o-deployment" # api_key: "your-azure-api-key-here" # api_base: "https://your-resource.openai.azure.com" @@ -518,7 +518,7 @@ global_vision_llm_configs: # - system_instructions: Custom prompt or empty string to use defaults # - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty # - citations_enabled: true = include citation instructions, false = include anti-citation instructions -# - All standard LiteLLM providers are supported +# - All standard LiteLLM provider adapter names are supported # - rpm/tpm: Optional rate limits for load balancing (requests/tokens per minute) # These help the router distribute load evenly and avoid rate limit errors # diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 9756cb32f..4c628b05a 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -283,7 +283,7 @@ class VisionProvider(StrEnum): class ConnectionProtocol(StrEnum): OLLAMA = "OLLAMA" OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" - NATIVE = "NATIVE" + ANTHROPIC = "ANTHROPIC" class ConnectionScope(StrEnum): @@ -1663,7 +1663,7 @@ class Connection(BaseModel, TimestampMixin): __tablename__ = "connections" protocol = Column(SQLAlchemyEnum(ConnectionProtocol), nullable=False, index=True) - native_provider = Column(String(100), nullable=True, index=True) + litellm_provider = Column(String(100), nullable=True, index=True) base_url = Column(String(500), nullable=True) api_key = Column(String, nullable=True) extra = Column(JSONB, nullable=False, default=dict, server_default="{}") diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 5872671b1..cae951c3a 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -8,6 +8,7 @@ from sqlalchemy.orm import selectinload from app.config import config from app.db import ( Connection, + ConnectionProtocol, ConnectionScope, Model, ModelSource, @@ -40,6 +41,16 @@ router = APIRouter() logger = logging.getLogger(__name__) +def _default_litellm_provider(protocol: ConnectionProtocol | str) -> str: + protocol_value = getattr(protocol, "value", str(protocol)) + defaults = { + ConnectionProtocol.OLLAMA.value: "ollama_chat", + ConnectionProtocol.ANTHROPIC.value: "anthropic", + ConnectionProtocol.OPENAI_COMPATIBLE.value: "openai", + } + return defaults.get(protocol_value, "openai") + + def _model_read(model: Model | dict) -> ModelRead: return ModelRead.model_validate(model) @@ -58,7 +69,7 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None return ConnectionRead( id=conn.id, protocol=conn.protocol, - native_provider=conn.native_provider, + litellm_provider=conn.litellm_provider, base_url=conn.base_url, extra=conn.extra or {}, scope=conn.scope, @@ -168,8 +179,12 @@ async def create_connection( Permission.LLM_CONFIGS_CREATE.value, "You don't have permission to create model connections in this search space", ) + payload = data.model_dump(exclude={"search_space_id"}) + if not payload.get("litellm_provider"): + payload["litellm_provider"] = _default_litellm_provider(data.protocol) + conn = Connection( - **data.model_dump(exclude={"search_space_id"}), + **payload, search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None, user_id=user.id, ) diff --git a/surfsense_backend/app/routes/new_llm_config_routes.py b/surfsense_backend/app/routes/new_llm_config_routes.py index 84d66bb13..531fa8730 100644 --- a/surfsense_backend/app/routes/new_llm_config_routes.py +++ b/surfsense_backend/app/routes/new_llm_config_routes.py @@ -57,7 +57,7 @@ def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead: litellm_params.get("base_model") if isinstance(litellm_params, dict) else None ) supports_image_input = derive_supports_image_input( - provider=provider_value, + litellm_provider=provider_value.lower(), model_name=config.model_name, base_model=base_model, custom_provider=config.custom_provider, @@ -147,7 +147,7 @@ async def get_global_new_llm_configs( else None ) supports_image_input = derive_supports_image_input( - provider=cfg.get("provider"), + litellm_provider=cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=cfg_base_model, custom_provider=cfg.get("custom_provider"), @@ -157,7 +157,7 @@ async def get_global_new_llm_configs( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("provider"), + "provider": cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 898077b7a..2cda04221 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -419,7 +419,7 @@ async def _get_llm_config_by_id( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("provider"), + "provider": cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base"), @@ -490,7 +490,7 @@ async def _get_image_gen_config_by_id( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("provider"), + "provider": cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, @@ -550,7 +550,7 @@ async def _get_vision_llm_config_by_id( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("provider"), + "provider": cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py index ea1ec4e88..306dd63c8 100644 --- a/surfsense_backend/app/schemas/model_connections.py +++ b/surfsense_backend/app/schemas/model_connections.py @@ -29,7 +29,7 @@ class ModelRead(BaseModel): class ConnectionRead(BaseModel): id: int protocol: ConnectionProtocol | str - native_provider: str | None = None + litellm_provider: str | None = None base_url: str | None = None extra: dict[str, Any] = Field(default_factory=dict) scope: ConnectionScope | str @@ -48,7 +48,7 @@ class ConnectionRead(BaseModel): class ConnectionCreate(BaseModel): protocol: ConnectionProtocol - native_provider: str | None = None + litellm_provider: str | None = Field(None, max_length=100) base_url: str | None = Field(None, max_length=500) api_key: str | None = None extra: dict[str, Any] = Field(default_factory=dict) @@ -58,7 +58,7 @@ class ConnectionCreate(BaseModel): class ConnectionUpdate(BaseModel): - native_provider: str | None = None + litellm_provider: str | None = Field(None, max_length=100) base_url: str | None = Field(None, max_length=500) api_key: str | None = None extra: dict[str, Any] | None = None diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index 42a4792a4..5e5b231f9 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import contextlib import logging from dataclasses import dataclass @@ -13,8 +12,7 @@ import httpx import litellm from app.db import Connection, ConnectionProtocol, Model, ModelSource -from app.services.model_resolver import NATIVE_PROVIDER_PREFIX, ensure_v1, to_litellm -from app.services.provider_api_base import resolve_api_base +from app.services.model_resolver import ensure_v1, to_litellm logger = logging.getLogger(__name__) @@ -36,6 +34,13 @@ def _auth_headers(conn: Connection) -> dict[str, str]: return {"Authorization": f"Bearer {conn.api_key}"} +def _anthropic_headers(conn: Connection) -> dict[str, str]: + headers = {"anthropic-version": "2023-06-01"} + if conn.api_key: + headers["x-api-key"] = conn.api_key + return headers + + def _docker_hint(url: str | None, exc_or_status: Any) -> str: raw = str(exc_or_status) if not url: @@ -56,24 +61,26 @@ def _docker_hint(url: str | None, exc_or_status: Any) -> str: async def verify_connection(conn: Connection) -> VerifyResult: - if not conn.base_url and conn.protocol in ( - ConnectionProtocol.OLLAMA, - ConnectionProtocol.OPENAI_COMPATIBLE, - ): + if not conn.base_url: return VerifyResult("UNREACHABLE", False, "Base URL is required.") if conn.protocol == ConnectionProtocol.OLLAMA: url = f"{conn.base_url.rstrip('/')}/api/version" elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: url = f"{ensure_v1(conn.base_url)}/models" + elif conn.protocol == ConnectionProtocol.ANTHROPIC: + url = f"{conn.base_url.rstrip('/')}/models" else: - # Native providers do not share one cheap health endpoint. The model - # probe exercises the real path and is the authoritative check. - return VerifyResult("OK", True, "Native provider configuration accepted.") + return VerifyResult("UNREACHABLE", False, "Unsupported connection protocol.") try: async with httpx.AsyncClient(timeout=VERIFY_TIMEOUT_SECONDS) as client: - response = await client.get(url, headers=_auth_headers(conn)) + headers = ( + _anthropic_headers(conn) + if conn.protocol == ConnectionProtocol.ANTHROPIC + else _auth_headers(conn) + ) + response = await client.get(url, headers=headers) if response.status_code in (401, 403): return VerifyResult("AUTH_FAILED", False, "Authentication failed.") if response.status_code == 404: @@ -156,39 +163,25 @@ async def _discover_openai_shaped_models(conn: Connection, base_url: str | None) ] -def _litellm_valid_model_ids(provider: str, api_key: str | None) -> list[str]: - if not api_key: +async def _discover_anthropic_models(conn: Connection) -> list[dict[str, Any]]: + if not conn.base_url: return [] - try: - models = litellm.get_valid_models( - check_provider_endpoint=True, - custom_llm_provider=provider, - api_key=api_key, - ) - except Exception as exc: - logger.warning("LiteLLM model discovery failed for provider %s: %s", provider, exc) - return [] - - provider_prefix = f"{provider}/" - return [ - model.removeprefix(provider_prefix) - for model in models - if isinstance(model, str) and model.strip() - ] - - -async def _discover_litellm_native_models(conn: Connection, provider: str) -> list[dict[str, Any]]: - model_ids = await asyncio.to_thread(_litellm_valid_model_ids, provider, conn.api_key) + url = f"{conn.base_url.rstrip('/')}/models" + async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: + response = await client.get(url, headers=_anthropic_headers(conn)) + response.raise_for_status() + models = response.json().get("data", []) return [ { - "model_id": model_id, - "display_name": model_id, + "model_id": item.get("id"), + "display_name": item.get("display_name") or item.get("id"), "source": ModelSource.DISCOVERED, - "capabilities": derive_capabilities(conn, model_id), - "metadata": {}, + "capabilities": derive_capabilities(conn, item.get("id"), item), + "metadata": item, } - for model_id in model_ids + for item in models + if item.get("id") ] @@ -231,20 +224,10 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]: ] elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: results = await _discover_openai_shaped_models(conn, conn.base_url) + elif conn.protocol == ConnectionProtocol.ANTHROPIC: + results = await _discover_anthropic_models(conn) else: - provider_key = (conn.native_provider or "").upper() - provider = NATIVE_PROVIDER_PREFIX.get(provider_key, provider_key.lower()) - api_base = resolve_api_base( - provider=provider_key, - provider_prefix=provider, - config_api_base=conn.base_url, - ) - if api_base: - results = await _discover_openai_shaped_models(conn, api_base) - elif provider: - results = await _discover_litellm_native_models(conn, provider) - else: - results = [] + results = [] if allowlist: results = [item for item in results if item["model_id"] in allowlist] diff --git a/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml b/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml index 017fa1eb3..9ea5e1a29 100644 --- a/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml +++ b/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml @@ -19,7 +19,7 @@ # so the resolved auto-pin id is never sent to a real LLM provider. # The values below only need to pass # auto_model_pin_service._is_usable_global_config() -# which requires id / model_name / provider / api_key all truthy. +# which requires id / model_name / litellm_provider / api_key all truthy. # # Why TWO entries (premium + free): # auto_model_pin_service.resolve_or_get_pinned_llm_config_id() splits @@ -44,9 +44,10 @@ global_llm_configs: anonymous_enabled: false seo_enabled: false quality_score: 1.0 - provider: "OPENAI" + litellm_provider: "openai" model_name: "fake-e2e-model-premium" api_key: "fake-e2e-api-key-not-for-production" + api_base: "https://api.openai.com/v1" supports_image_input: false quota_reserve_tokens: 1024 rpm: 1000 @@ -60,9 +61,10 @@ global_llm_configs: anonymous_enabled: false seo_enabled: false quality_score: 1.0 - provider: "OPENAI" + litellm_provider: "openai" model_name: "fake-e2e-model-free" api_key: "fake-e2e-api-key-not-for-production" + api_base: "https://api.openai.com/v1" supports_image_input: false quota_reserve_tokens: 1024 rpm: 1000 diff --git a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py index 2b6c76485..fff61f14b 100644 --- a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py +++ b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py @@ -25,7 +25,7 @@ _IMAGE_FIXTURE: list[dict] = [ { "id": -1, "name": "DALL-E 3", - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "dall-e-3", "api_key": "sk-test", "billing_tier": "free", @@ -33,7 +33,7 @@ _IMAGE_FIXTURE: list[dict] = [ { "id": -2, "name": "GPT-Image 1 (premium)", - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-image-1", "api_key": "sk-test", "billing_tier": "premium", @@ -41,7 +41,7 @@ _IMAGE_FIXTURE: list[dict] = [ { "id": -20_001, "name": "google/gemini-2.5-flash-image (OpenRouter)", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemini-2.5-flash-image", "api_key": "sk-or-test", "api_base": "https://openrouter.ai/api/v1", @@ -54,7 +54,7 @@ _VISION_FIXTURE: list[dict] = [ { "id": -1, "name": "GPT-4o Vision", - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-4o", "api_key": "sk-test", "billing_tier": "free", @@ -62,7 +62,7 @@ _VISION_FIXTURE: list[dict] = [ { "id": -2, "name": "Claude 3.5 Sonnet (premium)", - "provider": "ANTHROPIC", + "litellm_provider": "anthropic", "model_name": "claude-3-5-sonnet", "api_key": "sk-ant-test", "billing_tier": "premium", @@ -70,7 +70,7 @@ _VISION_FIXTURE: list[dict] = [ { "id": -30_001, "name": "openai/gpt-4o (OpenRouter)", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "openai/gpt-4o", "api_key": "sk-or-test", "api_base": "https://openrouter.ai/api/v1", diff --git a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py index b47d9134b..67d1112f3 100644 --- a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py +++ b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py @@ -26,7 +26,7 @@ _FIXTURE: list[dict] = [ "id": -1, "name": "GPT-4o (explicit true)", "description": "vision-capable, explicit YAML override", - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-4o", "api_key": "sk-test", "billing_tier": "free", @@ -36,7 +36,7 @@ _FIXTURE: list[dict] = [ "id": -2, "name": "DeepSeek V3 (explicit false)", "description": "OpenRouter dynamic — modality-derived false", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "deepseek/deepseek-v3.2-exp", "api_key": "sk-or-test", "api_base": "https://openrouter.ai/api/v1", @@ -47,7 +47,7 @@ _FIXTURE: list[dict] = [ "id": -10_010, "name": "Unannotated GPT-4o", "description": "no flag set — resolver should derive True via LiteLLM", - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-4o", "api_key": "sk-test", "billing_tier": "free", @@ -57,7 +57,7 @@ _FIXTURE: list[dict] = [ "id": -10_011, "name": "Unannotated unknown model", "description": "unmapped — default-allow True", - "provider": "CUSTOM", + "litellm_provider": "custom", "custom_provider": "brand_new_proxy", "model_name": "brand-new-model-x9", "api_key": "sk-test", diff --git a/surfsense_backend/tests/unit/services/test_model_connections.py b/surfsense_backend/tests/unit/services/test_model_connections.py index 98042501b..797f794b1 100644 --- a/surfsense_backend/tests/unit/services/test_model_connections.py +++ b/surfsense_backend/tests/unit/services/test_model_connections.py @@ -2,11 +2,12 @@ from app.services.global_model_catalog import materialize_global_model_catalog from app.services.model_resolver import ensure_v1, to_litellm -def test_openai_compatible_resolver_normalizes_v1() -> None: +def test_openai_compatible_resolver_uses_explicit_api_base() -> None: model, kwargs = to_litellm( { "protocol": "OPENAI_COMPATIBLE", - "base_url": "http://host.docker.internal:1234", + "litellm_provider": "openai", + "base_url": "http://host.docker.internal:1234/v1", "api_key": "local-key", "extra": {}, }, @@ -23,6 +24,7 @@ def test_ollama_resolver_uses_native_api_base() -> None: model, kwargs = to_litellm( { "protocol": "OLLAMA", + "litellm_provider": "ollama_chat", "base_url": "http://host.docker.internal:11434", "api_key": None, "extra": {}, @@ -40,9 +42,10 @@ def test_global_materialization_preserves_tier_and_keeps_key_server_side() -> No { "id": -101, "name": "OpenRouter Free", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "meta-llama/llama-3.1-8b-instruct:free", "api_key": "sk-global-secret", + "api_base": "https://openrouter.ai/api/v1", "billing_tier": "free", "anonymous_enabled": True, "seo_enabled": True, @@ -52,9 +55,10 @@ def test_global_materialization_preserves_tier_and_keeps_key_server_side() -> No { "id": -102, "name": "OpenRouter Premium", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "anthropic/claude-sonnet-4", "api_key": "sk-global-secret", + "api_base": "https://openrouter.ai/api/v1", "billing_tier": "premium", }, ], From 8f20a3257189807136ddf56fa942683bf4b366ce Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 18:21:07 +0530 Subject: [PATCH 018/212] refactor(model-connections): consolidate provider capability handling --- .../app/services/global_model_catalog.py | 2 +- .../openrouter_integration_service.py | 21 +--- .../app/services/pricing_registration.py | 7 +- .../app/services/provider_api_base.py | 106 ----------------- .../app/services/provider_capabilities.py | 31 +---- .../app/services/quality_score.py | 38 +++---- .../test_openrouter_integration_service.py | 12 +- .../services/test_pricing_registration.py | 17 ++- .../unit/services/test_provider_api_base.py | 107 ------------------ .../services/test_provider_capabilities.py | 28 ++--- .../tests/unit/services/test_quality_score.py | 6 +- 11 files changed, 64 insertions(+), 311 deletions(-) delete mode 100644 surfsense_backend/app/services/provider_api_base.py delete mode 100644 surfsense_backend/tests/unit/services/test_provider_api_base.py diff --git a/surfsense_backend/app/services/global_model_catalog.py b/surfsense_backend/app/services/global_model_catalog.py index a43f58b9e..e40b3a942 100644 --- a/surfsense_backend/app/services/global_model_catalog.py +++ b/surfsense_backend/app/services/global_model_catalog.py @@ -19,7 +19,7 @@ def _connection_key(conn: dict[str, Any]) -> tuple[Any, ...]: # the same provider/base can have different quota/rate limits upstream. return ( conn.get("protocol"), - conn.get("native_provider"), + conn.get("litellm_provider"), conn.get("base_url"), conn.get("api_key"), _freeze(conn.get("extra") or {}), diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 6454e2d58..6996f0fde 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -323,10 +323,10 @@ def _generate_configs( "seo_enabled": seo_enabled, "seo_slug": None, "quota_reserve_tokens": quota_reserve_tokens, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": model_id, "api_key": api_key, - "api_base": "", + "api_base": "https://openrouter.ai/api/v1", "rpm": free_rpm if tier == "free" else rpm, "tpm": free_tpm if tier == "free" else tpm, "litellm_params": dict(litellm_params), @@ -420,14 +420,9 @@ def _generate_image_gen_configs( "id": _stable_config_id(model_id, id_offset, taken), "name": name, "description": f"{name} via OpenRouter (image generation)", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": model_id, "api_key": api_key, - # Pin to OpenRouter's public base URL so a downstream call site - # that forgets ``resolve_api_base`` still doesn't inherit - # ``AZURE_OPENAI_ENDPOINT`` and 404 on - # ``image_generation/transformation`` (defense-in-depth, see - # ``provider_api_base`` docstring). "api_base": "https://openrouter.ai/api/v1", "api_version": None, "rpm": free_rpm if tier == "free" else rpm, @@ -504,13 +499,9 @@ def _generate_vision_llm_configs( "id": _stable_config_id(model_id, id_offset, taken), "name": name, "description": f"{name} via OpenRouter (vision)", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": model_id, "api_key": api_key, - # Pin to OpenRouter's public base URL so a downstream call site - # that forgets ``resolve_api_base`` still doesn't inherit - # ``AZURE_OPENAI_ENDPOINT`` (defense-in-depth, see - # ``provider_api_base`` docstring). "api_base": "https://openrouter.ai/api/v1", "api_version": None, "rpm": free_rpm if tier == "free" else rpm, @@ -710,7 +701,7 @@ class OpenRouterIntegrationService: ) # Re-blend health scores against the freshly fetched catalogue. Also - # re-stamps health for any YAML-curated cfg with provider==OPENROUTER + # re-stamps health for any YAML-curated cfg with litellm_provider=openrouter # so a hand-picked dead OR model is gated like a dynamic one. await self._enrich_health_safely(static_configs + new_configs, log_summary=True) @@ -787,7 +778,7 @@ class OpenRouterIntegrationService: the entire previous cycle's cache for this run. """ or_cfgs = [ - c for c in configs if str(c.get("provider", "")).upper() == "OPENROUTER" + c for c in configs if str(c.get("litellm_provider", "")).lower() == "openrouter" ] if not or_cfgs: return diff --git a/surfsense_backend/app/services/pricing_registration.py b/surfsense_backend/app/services/pricing_registration.py index de98e50c2..6b99fe723 100644 --- a/surfsense_backend/app/services/pricing_registration.py +++ b/surfsense_backend/app/services/pricing_registration.py @@ -143,12 +143,12 @@ def _register_chat_shape_configs( sample_keys: list[str] = [] for cfg in configs: - provider = str(cfg.get("provider") or "").upper() + provider = str(cfg.get("litellm_provider") or "").lower() model_name = str(cfg.get("model_name") or "").strip() litellm_params = cfg.get("litellm_params") or {} base_model = str(litellm_params.get("base_model") or model_name).strip() - if provider == "OPENROUTER": + if provider == "openrouter": entry = or_pricing.get(model_name) if entry: input_cost = _safe_float(entry.get("prompt")) @@ -189,12 +189,11 @@ def _register_chat_shape_configs( skipped_no_pricing += 1 continue aliases = _alias_set_for_yaml(provider, model_name, base_model) - provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower() count = _register( aliases, input_cost=input_cost, output_cost=output_cost, - provider=provider_slug, + provider=provider, ) if count > 0: registered_models += 1 diff --git a/surfsense_backend/app/services/provider_api_base.py b/surfsense_backend/app/services/provider_api_base.py deleted file mode 100644 index dca1f9462..000000000 --- a/surfsense_backend/app/services/provider_api_base.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision. - -LiteLLM falls back to the module-global ``litellm.api_base`` when an -individual call doesn't pass one, which silently inherits provider-agnostic -env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an -explicit ``api_base``, an ``openrouter/`` request can end up at an -Azure endpoint and 404 with ``Resource not found`` (real reproducer: -[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends -``/chat/completions`` to whatever inherited base it gets, regardless of -provider). - -The chat router has had this defense for a while -(``llm_router_service.py:466-478``). This module hoists the maps + cascade -into a tiny standalone helper so vision and image-gen can share the same -source of truth without an inter-service circular import. -""" - -from __future__ import annotations - -PROVIDER_DEFAULT_API_BASE: dict[str, str] = { - "openrouter": "https://openrouter.ai/api/v1", - "groq": "https://api.groq.com/openai/v1", - "mistral": "https://api.mistral.ai/v1", - "perplexity": "https://api.perplexity.ai", - "xai": "https://api.x.ai/v1", - "cerebras": "https://api.cerebras.ai/v1", - "deepinfra": "https://api.deepinfra.com/v1/openai", - "fireworks_ai": "https://api.fireworks.ai/inference/v1", - "together_ai": "https://api.together.xyz/v1", - "anyscale": "https://api.endpoints.anyscale.com/v1", - "cometapi": "https://api.cometapi.com/v1", - "sambanova": "https://api.sambanova.ai/v1", -} -"""Default ``api_base`` per LiteLLM provider prefix (lowercase). - -Only providers with a well-known, stable public base URL are listed — -self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai, -huggingface, databricks, cloudflare, replicate) are intentionally omitted -so their existing config-driven behaviour is preserved.""" - - -PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = { - "DEEPSEEK": "https://api.deepseek.com/v1", - "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1", - "MOONSHOT": "https://api.moonshot.ai/v1", - "ZHIPU": "https://open.bigmodel.cn/api/paas/v4", - "MINIMAX": "https://api.minimax.io/v1", -} -"""Canonical provider key (uppercase) → base URL. - -Used when the LiteLLM provider prefix is the generic ``openai`` shim but the -config's ``provider`` field tells us which API it actually is (DeepSeek, -Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each -has its own base URL).""" - - -def resolve_api_base( - *, - provider: str | None, - provider_prefix: str | None, - config_api_base: str | None, -) -> str | None: - """Resolve a non-Azure-leaking ``api_base`` for a deployment. - - Cascade (first non-empty wins): - 1. The config's own ``api_base`` (whitespace-only treated as missing). - 2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``. - 3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``. - 4. ``None`` — caller should NOT set ``api_base`` and let the LiteLLM - provider integration apply its own default (e.g. AzureOpenAI's - deployment-derived URL, custom provider's per-deployment URL). - - Args: - provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``, - ``"DEEPSEEK"``). Case-insensitive. - provider_prefix: The LiteLLM model-string prefix the same call - site builds for the model id (e.g. ``"openrouter"``, - ``"groq"``). Case-insensitive. - config_api_base: ``api_base`` from the global YAML / DB row / - OpenRouter dynamic config. Empty / whitespace-only means - "missing" — the resolver still applies the cascade. - - Returns: - A URL string, or ``None`` if no default applies for this provider. - """ - if config_api_base and config_api_base.strip(): - return config_api_base - - if provider: - key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper()) - if key_default: - return key_default - - if provider_prefix: - prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower()) - if prefix_default: - return prefix_default - - return None - - -__all__ = [ - "PROVIDER_DEFAULT_API_BASE", - "PROVIDER_KEY_DEFAULT_API_BASE", - "resolve_api_base", -] diff --git a/surfsense_backend/app/services/provider_capabilities.py b/surfsense_backend/app/services/provider_capabilities.py index 9e1433214..9521ef7a4 100644 --- a/surfsense_backend/app/services/provider_capabilities.py +++ b/surfsense_backend/app/services/provider_capabilities.py @@ -46,26 +46,12 @@ from collections.abc import Iterable import litellm -from app.services.model_resolver import NATIVE_PROVIDER_PREFIX - logger = logging.getLogger(__name__) -# Provider-name → LiteLLM model-prefix map. -# -# Owned here because ``app.services.provider_capabilities`` is the -# only edge that's safe to call from ``app.config``'s YAML loader at -# class-body init time. ``app.agents.chat.runtime.llm_config`` re-exports -# this constant under the historical ``PROVIDER_MAP`` name; placing the -# map there directly would re-introduce the -# ``app.config -> ... -> deliverables/tools/generate_image -> -# app.config`` cycle that prompted the move. -_PROVIDER_PREFIX_MAP = NATIVE_PROVIDER_PREFIX - - def _candidate_model_strings( *, - provider: str | None, + litellm_provider: str | None, model_name: str | None, base_model: str | None, custom_provider: str | None, @@ -92,12 +78,7 @@ def _candidate_model_strings( seen.add(key) candidates.append(key) - provider_prefix: str | None = None - if provider: - provider_prefix = _PROVIDER_PREFIX_MAP.get(provider.upper(), provider.lower()) - if custom_provider: - # ``custom_provider`` overrides everything for CUSTOM/proxy setups. - provider_prefix = custom_provider + provider_prefix = custom_provider or litellm_provider primary_model = base_model or model_name bare_model = model_name @@ -132,7 +113,7 @@ def _candidate_model_strings( def derive_supports_image_input( *, - provider: str | None = None, + litellm_provider: str | None = None, model_name: str | None = None, base_model: str | None = None, custom_provider: str | None = None, @@ -166,7 +147,7 @@ def derive_supports_image_input( return False for model_string, custom_llm_provider in _candidate_model_strings( - provider=provider, + litellm_provider=litellm_provider, model_name=model_name, base_model=base_model, custom_provider=custom_provider, @@ -191,7 +172,7 @@ def derive_supports_image_input( def is_known_text_only_chat_model( *, - provider: str | None = None, + litellm_provider: str | None = None, model_name: str | None = None, base_model: str | None = None, custom_provider: str | None = None, @@ -212,7 +193,7 @@ def is_known_text_only_chat_model( leads to the regression we're fixing here. """ for model_string, custom_llm_provider in _candidate_model_strings( - provider=provider, + litellm_provider=litellm_provider, model_name=model_name, base_model=base_model, custom_provider=custom_provider, diff --git a/surfsense_backend/app/services/quality_score.py b/surfsense_backend/app/services/quality_score.py index 2fb37de21..95484439b 100644 --- a/surfsense_backend/app/services/quality_score.py +++ b/surfsense_backend/app/services/quality_score.py @@ -108,25 +108,23 @@ PROVIDER_PRESTIGE_OR: dict[str, int] = { # YAML provider field (the upstream API shape the operator selected). PROVIDER_PRESTIGE_YAML: dict[str, int] = { - "AZURE_OPENAI": 50, - "OPENAI": 50, - "ANTHROPIC": 50, - "GOOGLE": 50, - "VERTEX_AI": 50, - "GEMINI": 50, - "XAI": 50, - "MISTRAL": 38, - "DEEPSEEK": 38, - "COHERE": 38, - "GROQ": 30, - "TOGETHER_AI": 28, - "FIREWORKS_AI": 28, - "PERPLEXITY": 28, - "MINIMAX": 28, - "BEDROCK": 28, - "OPENROUTER": 25, - "OLLAMA": 12, - "CUSTOM": 12, + "azure": 50, + "openai": 50, + "anthropic": 50, + "gemini": 50, + "vertex_ai": 50, + "xai": 50, + "mistral": 38, + "deepseek": 38, + "cohere": 38, + "groq": 30, + "together_ai": 28, + "fireworks_ai": 28, + "perplexity": 28, + "bedrock": 28, + "openrouter": 25, + "ollama_chat": 12, + "custom": 12, } @@ -275,7 +273,7 @@ def static_score_yaml(cfg: dict) -> int: listed this model. Pricing / context fall through to lazy ``litellm`` lookups; failures are silent (we just lose those sub-points). """ - provider = str(cfg.get("provider", "")).upper() + provider = str(cfg.get("litellm_provider", "")).lower() base = PROVIDER_PRESTIGE_YAML.get(provider, 15) model_name = cfg.get("model_name") or "" diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py index 88fcf2db3..9d4c1a04b 100644 --- a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -263,11 +263,10 @@ def test_generate_image_gen_configs_filters_by_image_output(): # Each config must carry ``billing_tier`` for routing in image_generation_routes. for c in cfgs: assert c["billing_tier"] in {"free", "premium"} - assert c["provider"] == "OPENROUTER" + assert c["litellm_provider"] == "openrouter" assert c[_OPENROUTER_DYNAMIC_MARKER] is True - # Defense-in-depth: emit the OpenRouter base URL at source so a - # downstream call site that forgets ``resolve_api_base`` still - # doesn't 404 against an inherited Azure endpoint. + # Emit the OpenRouter base URL at source so every call path passes an + # explicit api_base and cannot inherit a process-global endpoint. assert c["api_base"] == "https://openrouter.ai/api/v1" @@ -346,9 +345,8 @@ def test_generate_vision_llm_configs_filters_by_image_input_text_output(): assert cfg["input_cost_per_token"] == pytest.approx(5e-6) assert cfg["output_cost_per_token"] == pytest.approx(15e-6) assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True - # Defense-in-depth: emit the OpenRouter base URL at source so a - # downstream call site that forgets ``resolve_api_base`` still - # doesn't inherit an Azure endpoint. + # Emit the OpenRouter base URL at source so every call path passes an + # explicit api_base and cannot inherit a process-global endpoint. assert cfg["api_base"] == "https://openrouter.ai/api/v1" diff --git a/surfsense_backend/tests/unit/services/test_pricing_registration.py b/surfsense_backend/tests/unit/services/test_pricing_registration.py index e97250ff2..c9adc6aac 100644 --- a/surfsense_backend/tests/unit/services/test_pricing_registration.py +++ b/surfsense_backend/tests/unit/services/test_pricing_registration.py @@ -186,7 +186,7 @@ def test_openrouter_models_register_under_aliases(monkeypatch): [ { "id": 1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "anthropic/claude-3-5-sonnet", } ], @@ -228,7 +228,7 @@ def test_yaml_override_registers_under_alias_set(monkeypatch): [ { "id": 1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5.4", "litellm_params": { "base_model": "gpt-5.4", @@ -243,7 +243,6 @@ def test_yaml_override_registers_under_alias_set(monkeypatch): keys = spy.all_keys assert "gpt-5.4" in keys - assert "azure_openai/gpt-5.4" in keys assert "azure/gpt-5.4" in keys payload = spy.calls[0] @@ -271,7 +270,7 @@ def test_no_override_means_no_registration(monkeypatch): [ { "id": 1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-4o", "litellm_params": {"base_model": "gpt-4o"}, } @@ -302,7 +301,7 @@ def test_openrouter_skipped_when_pricing_missing(monkeypatch): [ { "id": 1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "anthropic/claude-3-5-sonnet", } ], @@ -349,12 +348,12 @@ def test_register_continues_after_individual_failure(monkeypatch, caplog): [ { "id": 1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "anthropic/claude-3-5-sonnet", }, { "id": 2, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "custom-deployment", "litellm_params": { "base_model": "custom-deployment", @@ -396,7 +395,7 @@ def test_vision_configs_registered_with_chat_shape(monkeypatch): [ { "id": -1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "openai/gpt-4o", "billing_tier": "premium", "input_cost_per_token": 5e-6, @@ -433,7 +432,7 @@ def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch): [ { "id": -1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemini-2.5-flash", "billing_tier": "premium", "input_cost_per_token": 1e-6, diff --git a/surfsense_backend/tests/unit/services/test_provider_api_base.py b/surfsense_backend/tests/unit/services/test_provider_api_base.py deleted file mode 100644 index 12cd0a3d5..000000000 --- a/surfsense_backend/tests/unit/services/test_provider_api_base.py +++ /dev/null @@ -1,107 +0,0 @@ -"""Unit tests for the shared ``api_base`` resolver. - -The cascade exists so vision and image-gen call sites can't silently -inherit ``litellm.api_base`` (commonly set by ``AZURE_OPENAI_ENDPOINT``) -when an OpenRouter / Groq / etc. config ships an empty string. See -``provider_api_base`` module docstring for the original repro -(OpenRouter image-gen 404-ing against an Azure endpoint). -""" - -from __future__ import annotations - -import pytest - -from app.services.provider_api_base import ( - PROVIDER_DEFAULT_API_BASE, - PROVIDER_KEY_DEFAULT_API_BASE, - resolve_api_base, -) - -pytestmark = pytest.mark.unit - - -def test_config_value_wins_over_defaults(): - """A non-empty config value is always returned verbatim, even when the - provider has a default — the operator gets the last word.""" - result = resolve_api_base( - provider="OPENROUTER", - provider_prefix="openrouter", - config_api_base="https://my-openrouter-mirror.example.com/v1", - ) - assert result == "https://my-openrouter-mirror.example.com/v1" - - -def test_provider_key_default_when_config_missing(): - """``DEEPSEEK`` shares the ``openai`` LiteLLM prefix but has its own - base URL — the provider-key map must take precedence over the prefix - map so DeepSeek requests don't go to OpenAI.""" - result = resolve_api_base( - provider="DEEPSEEK", - provider_prefix="openai", - config_api_base=None, - ) - assert result == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"] - - -def test_provider_prefix_default_when_no_key_default(): - result = resolve_api_base( - provider="OPENROUTER", - provider_prefix="openrouter", - config_api_base=None, - ) - assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] - - -def test_unknown_provider_returns_none(): - """When neither map matches we return ``None`` so the caller can let - LiteLLM apply its own provider-integration default (Azure deployment - URL, custom-provider URL, etc.).""" - result = resolve_api_base( - provider="SOMETHING_NEW", - provider_prefix="something_new", - config_api_base=None, - ) - assert result is None - - -def test_empty_string_config_treated_as_missing(): - """The original bug: OpenRouter dynamic configs ship ``api_base=""`` - and downstream call sites use ``if cfg.get("api_base"):`` — empty - strings are falsy in Python but the cascade has to step in anyway.""" - result = resolve_api_base( - provider="OPENROUTER", - provider_prefix="openrouter", - config_api_base="", - ) - assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] - - -def test_whitespace_only_config_treated_as_missing(): - """A config value of ``" "`` is a configuration mistake — treat it - as missing instead of forwarding whitespace to LiteLLM (which would - almost certainly 404).""" - result = resolve_api_base( - provider="OPENROUTER", - provider_prefix="openrouter", - config_api_base=" ", - ) - assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] - - -def test_provider_case_insensitive(): - """Some call sites pass the provider lowercase (DB enum value), others - uppercase (YAML key). Both must resolve.""" - upper = resolve_api_base( - provider="DEEPSEEK", provider_prefix="openai", config_api_base=None - ) - lower = resolve_api_base( - provider="deepseek", provider_prefix="openai", config_api_base=None - ) - assert upper == lower == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"] - - -def test_all_inputs_none_returns_none(): - assert ( - resolve_api_base(provider=None, provider_prefix=None, config_api_base=None) - is None - ) diff --git a/surfsense_backend/tests/unit/services/test_provider_capabilities.py b/surfsense_backend/tests/unit/services/test_provider_capabilities.py index aac88977f..c327c2c87 100644 --- a/surfsense_backend/tests/unit/services/test_provider_capabilities.py +++ b/surfsense_backend/tests/unit/services/test_provider_capabilities.py @@ -32,7 +32,7 @@ pytestmark = pytest.mark.unit def test_or_modalities_with_image_returns_true(): assert ( derive_supports_image_input( - provider="OPENROUTER", + litellm_provider="openrouter", model_name="openai/gpt-4o", openrouter_input_modalities=["text", "image"], ) @@ -43,7 +43,7 @@ def test_or_modalities_with_image_returns_true(): def test_or_modalities_text_only_returns_false(): assert ( derive_supports_image_input( - provider="OPENROUTER", + litellm_provider="openrouter", model_name="deepseek/deepseek-v3.2-exp", openrouter_input_modalities=["text"], ) @@ -57,7 +57,7 @@ def test_or_modalities_empty_list_returns_false(): to LiteLLM.""" assert ( derive_supports_image_input( - provider="OPENROUTER", + litellm_provider="openrouter", model_name="weird/empty-modalities", openrouter_input_modalities=[], ) @@ -70,7 +70,7 @@ def test_or_modalities_none_falls_through_to_litellm(): to LiteLLM. Using ``openai/gpt-4o`` which is in LiteLLM's map.""" assert ( derive_supports_image_input( - provider="OPENAI", + litellm_provider="openai", model_name="gpt-4o", openrouter_input_modalities=None, ) @@ -86,7 +86,7 @@ def test_or_modalities_none_falls_through_to_litellm(): def test_litellm_known_vision_model_returns_true(): assert ( derive_supports_image_input( - provider="OPENAI", + litellm_provider="openai", model_name="gpt-4o", ) is True @@ -100,7 +100,7 @@ def test_litellm_base_model_wins_over_model_name(): doesn't know) would shadow the real capability.""" assert ( derive_supports_image_input( - provider="AZURE_OPENAI", + litellm_provider="azure", model_name="my-azure-deployment-id", base_model="gpt-4o", ) @@ -112,7 +112,7 @@ def test_litellm_unknown_model_default_allows(): """Default-allow on unknown — the safety net is the actual block.""" assert ( derive_supports_image_input( - provider="CUSTOM", + litellm_provider="custom", model_name="brand-new-model-x9-unmapped", custom_provider="brand_new_proxy", ) @@ -128,7 +128,7 @@ def test_litellm_known_text_only_returns_false(): # Sanity: confirm the helper's negative path. We use a small model # known not to support vision per the map. result = derive_supports_image_input( - provider="DEEPSEEK", + litellm_provider="openai", model_name="deepseek-chat", ) # We accept either False (LiteLLM said explicit no) or True @@ -147,7 +147,7 @@ def test_litellm_known_text_only_returns_false(): def test_is_known_text_only_returns_false_for_vision_model(): assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="gpt-4o", ) is False @@ -160,7 +160,7 @@ def test_is_known_text_only_returns_false_for_unknown_model(): fixing.""" assert ( is_known_text_only_chat_model( - provider="CUSTOM", + litellm_provider="custom", model_name="brand-new-model-x9-unmapped", custom_provider="brand_new_proxy", ) @@ -181,7 +181,7 @@ def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch): assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="gpt-4o", ) is False @@ -201,7 +201,7 @@ def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch): assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="any-model", ) is True @@ -218,7 +218,7 @@ def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch): assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="any-model", ) is False @@ -237,7 +237,7 @@ def test_is_known_text_only_returns_false_on_missing_key(monkeypatch): assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="any-model", ) is False diff --git a/surfsense_backend/tests/unit/services/test_quality_score.py b/surfsense_backend/tests/unit/services/test_quality_score.py index 6fbc8fd62..369c8b8f3 100644 --- a/surfsense_backend/tests/unit/services/test_quality_score.py +++ b/surfsense_backend/tests/unit/services/test_quality_score.py @@ -228,7 +228,7 @@ def test_static_score_or_recent_release_beats_year_old_same_provider(): def test_static_score_yaml_includes_operator_bonus(): cfg = { - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "litellm_params": {"base_model": "azure/gpt-5"}, } @@ -238,7 +238,7 @@ def test_static_score_yaml_includes_operator_bonus(): def test_static_score_yaml_unknown_provider_still_carries_bonus(): cfg = { - "provider": "SOME_NEW_PROVIDER", + "litellm_provider": "some_new_provider", "model_name": "weird-model", } score = static_score_yaml(cfg) @@ -247,7 +247,7 @@ def test_static_score_yaml_unknown_provider_still_carries_bonus(): def test_static_score_yaml_clamped_0_to_100(): cfg = { - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "litellm_params": {"base_model": "azure/gpt-5"}, } From c28c4f5785bb0bbdfe8997ef157f8fedc3276d72 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 18:22:23 +0530 Subject: [PATCH 019/212] feat(chat): route models by provider capabilities --- .../app/agents/chat/runtime/llm_config.py | 53 ++--- .../app/routes/anonymous_chat_routes.py | 4 +- .../app/routes/vision_llm_routes.py | 2 +- .../app/services/auto_model_pin_service.py | 206 ++++++++++++++---- .../app/services/llm_router_service.py | 10 +- surfsense_backend/app/services/llm_service.py | 75 +++---- .../app/services/model_list_service.py | 10 +- .../app/services/model_resolver.py | 110 +++------- .../app/services/vision_llm_router_service.py | 14 +- .../app/services/vision_model_list_service.py | 8 +- .../flows/new_chat/llm_capability.py | 2 +- .../streaming/flows/new_chat/title_gen.py | 17 +- .../chat/streaming/flows/shared/llm_bundle.py | 137 ++++++++++-- .../services/test_auto_model_pin_service.py | 68 +++--- .../services/test_auto_pin_image_aware.py | 6 +- .../services/test_llm_router_pool_filter.py | 8 +- .../services/test_or_health_enrichment.py | 6 +- .../test_stream_new_chat_image_safety_net.py | 12 +- 18 files changed, 429 insertions(+), 319 deletions(-) diff --git a/surfsense_backend/app/agents/chat/runtime/llm_config.py b/surfsense_backend/app/agents/chat/runtime/llm_config.py index 03d7f548e..b9344e001 100644 --- a/surfsense_backend/app/agents/chat/runtime/llm_config.py +++ b/surfsense_backend/app/agents/chat/runtime/llm_config.py @@ -2,9 +2,9 @@ LLM configuration utilities for SurfSense agents. This module provides functions for loading LLM configurations from: -1. Auto mode (ID 0) - Uses LiteLLM Router for load balancing +1. Auto mode (ID 0) - Resolved by callers to a concrete model-connection model 2. YAML files (global configs with negative IDs) -3. Database NewLLMConfig table (user-created configs with positive IDs) +3. Database model-connections table (user-created configs with positive IDs) It also provides utilities for creating ChatLiteLLM instances and managing prompt configurations. @@ -33,9 +33,7 @@ from app.agents.chat.runtime.prompt_caching import ( from app.services.llm_router_service import ( AUTO_MODE_ID, ChatLiteLLMRouter, - LLMRouterService, _sanitize_content, - get_auto_mode_llm, is_auto_mode, ) @@ -92,14 +90,6 @@ class SanitizedChatLiteLLM(ChatLiteLLM): yield chunk -# Re-exported under the historical name ``PROVIDER_MAP``. Source of truth lives -# in provider_capabilities so the YAML loader can resolve prefixes during -# app.config init without importing the agent/tools tree. -from app.services.provider_capabilities import ( # noqa: E402 - _PROVIDER_PREFIX_MAP as PROVIDER_MAP, -) - - def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None: """Attach a ``profile`` dict to ChatLiteLLM with model context metadata.""" try: @@ -122,7 +112,8 @@ class AgentConfig: Complete configuration for the SurfSense agent. This combines LLM settings with prompt configuration from NewLLMConfig. - Supports Auto mode (ID 0) which uses LiteLLM Router for load balancing. + Supports Auto mode metadata (ID 0). Runtime callers must resolve Auto to + a concrete global or BYOK model before constructing ChatLiteLLM. """ # LLM Model Settings @@ -219,7 +210,7 @@ class AgentConfig: # BYOK rows have no curated flag; ask LiteLLM (default-allow on # unknown). The streaming safety net still blocks explicit text-only. supports_image_input=derive_supports_image_input( - provider=provider_value, + litellm_provider=provider_value.lower(), model_name=config.model_name, base_model=base_model, custom_provider=config.custom_provider, @@ -238,7 +229,7 @@ class AgentConfig: system_instructions = yaml_config.get("system_instructions", "") - provider = yaml_config.get("provider", "").upper() + provider = yaml_config.get("litellm_provider", "") model_name = yaml_config.get("model_name", "") custom_provider = yaml_config.get("custom_provider") litellm_params = yaml_config.get("litellm_params") or {} @@ -254,7 +245,7 @@ class AgentConfig: supports_image_input = bool(yaml_config.get("supports_image_input")) else: supports_image_input = derive_supports_image_input( - provider=provider, + litellm_provider=provider, model_name=model_name, base_model=base_model, custom_provider=custom_provider, @@ -383,9 +374,6 @@ async def load_agent_config( ) -> "AgentConfig | None": """Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB.""" if is_auto_mode(config_id): - if not LLMRouterService.is_initialized(): - print("Error: Auto mode requested but LLM Router not initialized") - return None return AgentConfig.from_auto_mode() if config_id < 0: @@ -408,9 +396,8 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: if llm_config.get("custom_provider"): model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}" else: - provider = llm_config.get("provider", "").upper() - provider_prefix = PROVIDER_MAP.get(provider, provider.lower()) - model_string = f"{provider_prefix}/{llm_config['model_name']}" + litellm_provider = llm_config.get("litellm_provider", "openai") + model_string = f"{litellm_provider}/{llm_config['model_name']}" litellm_kwargs = { "model": model_string, @@ -433,29 +420,15 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: def create_chat_litellm_from_agent_config( agent_config: AgentConfig, ) -> ChatLiteLLM | ChatLiteLLMRouter | None: - """Create a ChatLiteLLM (or, for Auto mode, a load-balancing router) from config.""" + """Create a ChatLiteLLM from an already resolved concrete model config.""" if agent_config.is_auto_mode: - if not LLMRouterService.is_initialized(): - print("Error: Auto mode requested but LLM Router not initialized") - return None - try: - router_llm = get_auto_mode_llm() - if router_llm is not None: - # Universal injection points only: auto-mode fans out across - # providers, so provider-specific kwargs have no known target. - apply_litellm_prompt_caching(router_llm, agent_config=agent_config) - return router_llm - except Exception as e: - print(f"Error creating ChatLiteLLMRouter: {e}") - return None + print("Error: Auto mode must be resolved to a concrete model before LLM creation") + return None if agent_config.custom_provider: model_string = f"{agent_config.custom_provider}/{agent_config.model_name}" else: - provider_prefix = PROVIDER_MAP.get( - agent_config.provider, agent_config.provider.lower() - ) - model_string = f"{provider_prefix}/{agent_config.model_name}" + model_string = f"{agent_config.provider}/{agent_config.model_name}" litellm_kwargs = { "model": model_string, diff --git a/surfsense_backend/app/routes/anonymous_chat_routes.py b/surfsense_backend/app/routes/anonymous_chat_routes.py index ad3277375..aba1a3a12 100644 --- a/surfsense_backend/app/routes/anonymous_chat_routes.py +++ b/surfsense_backend/app/routes/anonymous_chat_routes.py @@ -132,7 +132,7 @@ async def list_anonymous_models(): id=cfg.get("id", 0), name=cfg.get("name", ""), description=cfg.get("description"), - provider=cfg.get("provider", ""), + provider=cfg.get("litellm_provider", ""), model_name=cfg.get("model_name", ""), billing_tier=cfg.get("billing_tier", "free"), is_premium=cfg.get("billing_tier", "free") == "premium", @@ -161,7 +161,7 @@ async def get_anonymous_model(slug: str): id=cfg.get("id", 0), name=cfg.get("name", ""), description=cfg.get("description"), - provider=cfg.get("provider", ""), + provider=cfg.get("litellm_provider", ""), model_name=cfg.get("model_name", ""), billing_tier=cfg.get("billing_tier", "free"), is_premium=cfg.get("billing_tier", "free") == "premium", diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py index e4f08f604..df218daac 100644 --- a/surfsense_backend/app/routes/vision_llm_routes.py +++ b/surfsense_backend/app/routes/vision_llm_routes.py @@ -96,7 +96,7 @@ async def get_global_vision_llm_configs( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("provider"), + "provider": cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index 9bbca8669..ee8c4b8dc 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -23,9 +23,10 @@ from uuid import UUID from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from app.config import config -from app.db import NewChatThread +from app.db import Connection, Model, NewChatThread from app.services.quality_score import _QUALITY_TOP_K from app.services.token_quota_service import TokenQuotaService @@ -61,11 +62,20 @@ def _is_usable_global_config(cfg: dict) -> bool: return bool( cfg.get("id") is not None and cfg.get("model_name") - and cfg.get("provider") + and cfg.get("litellm_provider") and cfg.get("api_key") ) +def _has_capability(model: dict | Model, capability: str) -> bool: + caps = ( + model.get("capabilities", {}) + if isinstance(model, dict) + else model.capabilities or {} + ) + return bool(caps.get(capability)) + + def _prune_runtime_cooldowns(now_ts: float | None = None) -> None: now = time.time() if now_ts is None else now_ts stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now] @@ -186,15 +196,19 @@ def _cfg_supports_image_input(cfg: dict) -> bool: else None ) return derive_supports_image_input( - provider=cfg.get("provider"), + litellm_provider=cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), ) -def _global_candidates(*, requires_image_input: bool = False) -> list[dict]: - """Return Auto-eligible global cfgs. +def _global_candidates( + *, + capability: str = "chat", + requires_image_input: bool = False, +) -> list[dict]: + """Return Auto-eligible global virtual models. Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers @@ -205,17 +219,135 @@ def _global_candidates(*, requires_image_input: bool = False) -> list[dict]: filters out configs whose ``supports_image_input`` resolves to False so a text-only deployment can't be pinned for an image request. """ - candidates = [ - cfg + connection_by_id = { + int(conn.get("id")): conn + for conn in config.GLOBAL_CONNECTIONS + if conn.get("id") is not None + } + config_by_model_name = { + cfg.get("model_name"): cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg) - and not cfg.get("health_gated") - and not _is_runtime_cooled_down(int(cfg.get("id", 0))) - and (not requires_image_input or _cfg_supports_image_input(cfg)) - ] + } + candidates: list[dict] = [] + for model in config.GLOBAL_MODELS: + model_id = int(model.get("id", 0)) + if model_id >= 0 or _is_runtime_cooled_down(model_id): + continue + if not _has_capability(model, capability): + continue + cfg = config_by_model_name.get(model.get("model_id")) or {} + if cfg.get("health_gated"): + continue + if requires_image_input and not _has_capability(model, "vision"): + continue + if requires_image_input and cfg and not _cfg_supports_image_input(cfg): + continue + connection = connection_by_id.get(int(model.get("connection_id", 0))) + if not connection: + continue + catalog = model.get("catalog") or {} + candidates.append( + { + "id": model_id, + "model_id": model.get("model_id"), + "source": "global", + "connection": connection, + "capabilities": model.get("capabilities") or {}, + "billing_tier": model.get("billing_tier", "free"), + "litellm_provider": connection.get("litellm_provider"), + "model_name": model.get("model_id"), + "auto_pin_tier": catalog.get("auto_pin_tier") + or cfg.get("auto_pin_tier") + or "A", + "quality_score": catalog.get("quality_score") + or cfg.get("quality_score") + or cfg.get("quality_score_static") + or 50, + } + ) return sorted(candidates, key=lambda c: int(c.get("id", 0))) +async def _db_candidates( + session: AsyncSession, + *, + search_space_id: int, + user_id: str | UUID | None, + capability: str, + requires_image_input: bool = False, +) -> list[dict]: + parsed_user_id = _to_uuid(user_id) + stmt = ( + select(Model) + .options(selectinload(Model.connection)) + .join(Connection, Model.connection_id == Connection.id) + .where(Model.enabled.is_(True), Connection.enabled.is_(True)) + ) + result = await session.execute(stmt) + candidates: list[dict] = [] + for model in result.scalars().all(): + conn = model.connection + if not conn: + continue + if conn.search_space_id is not None and conn.search_space_id != search_space_id: + continue + if conn.user_id is not None and parsed_user_id is not None and conn.user_id != parsed_user_id: + continue + if conn.user_id is not None and parsed_user_id is None: + continue + if not _has_capability(model, capability): + continue + if requires_image_input and not _has_capability(model, "vision"): + continue + model_id = int(model.id) + if _is_runtime_cooled_down(model_id): + continue + catalog = model.catalog or {} + candidates.append( + { + "id": model_id, + "model_id": model.model_id, + "source": "db", + "connection": conn, + "capabilities": model.capabilities or {}, + "billing_tier": "byok", + "litellm_provider": conn.litellm_provider, + "model_name": model.model_id, + "auto_pin_tier": catalog.get("auto_pin_tier") or "BYOK", + "quality_score": catalog.get("quality_score") or 75, + } + ) + return sorted(candidates, key=lambda c: int(c.get("id", 0))) + + +async def auto_model_candidates( + session: AsyncSession, + *, + search_space_id: int, + user_id: str | UUID | None, + capability: str, + requires_image_input: bool = False, + exclude_model_ids: set[int] | None = None, +) -> list[dict]: + excluded_ids = {int(mid) for mid in (exclude_model_ids or set())} + db_candidates = await _db_candidates( + session, + search_space_id=search_space_id, + user_id=user_id, + capability=capability, + requires_image_input=requires_image_input, + ) + candidates = [ + *_global_candidates( + capability=capability, + requires_image_input=requires_image_input, + ), + *db_candidates, + ] + return [c for c in candidates if int(c.get("id", 0)) not in excluded_ids] + + def _tier_of(cfg: dict) -> str: return str(cfg.get("billing_tier", "free")).lower() @@ -223,8 +355,9 @@ def _tier_of(cfg: dict) -> str: def _is_preferred_premium_auto_config(cfg: dict) -> bool: """Return True for the operator-preferred premium Auto model.""" return ( - _tier_of(cfg) == "premium" - and str(cfg.get("provider", "")).upper() == "AZURE_OPENAI" + cfg.get("source") == "global" + and _tier_of(cfg) == "premium" + and str(cfg.get("litellm_provider", "")).lower() == "azure" and str(cfg.get("model_name", "")).lower() == "gpt-5.4" ) @@ -251,6 +384,11 @@ def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]: return top_k[idx], len(top_k) +def choose_auto_model_candidate(candidates: list[dict], seed_id: int) -> dict: + selected, _ = _select_pin(candidates, seed_id) + return selected + + def _to_uuid(user_id: str | UUID | None) -> UUID | None: if user_id is None: return None @@ -326,20 +464,23 @@ async def resolve_or_get_pinned_llm_config_id( ) excluded_ids = {int(cid) for cid in (exclude_config_ids or set())} - candidates = [ - c - for c in _global_candidates(requires_image_input=requires_image_input) - if int(c.get("id", 0)) not in excluded_ids - ] + candidates = await auto_model_candidates( + session, + search_space_id=search_space_id, + user_id=user_id, + capability="chat", + requires_image_input=requires_image_input, + exclude_model_ids=excluded_ids, + ) if not candidates: if requires_image_input: # Distinguish the "no vision-capable cfg" case from generic # "no usable cfg" so the streaming task can map this to the # MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error. raise ValueError( - "No vision-capable global LLM configs are available for Auto mode" + "No vision-capable LLM models are available for Auto mode" ) - raise ValueError("No usable global LLM configs are available for Auto mode") + raise ValueError("No usable LLM models are available for Auto mode") candidate_by_id = {int(c["id"]): c for c in candidates} # Reuse an existing valid pin without re-checking current quota (no silent @@ -379,24 +520,13 @@ async def resolve_or_get_pinned_llm_config_id( # log that explicitly so operators can correlate the re-pin with # the user's image attachment instead of suspecting a cooldown. if requires_image_input: - try: - pinned_global = next( - c - for c in config.GLOBAL_LLM_CONFIGS - if int(c.get("id", 0)) == int(pinned_id) - ) - except StopIteration: - pinned_global = None - if pinned_global is not None and not _cfg_supports_image_input( - pinned_global - ): - logger.info( - "auto_pin_repinned_for_image thread_id=%s search_space_id=%s " - "previous_config_id=%s", - thread_id, - search_space_id, - pinned_id, - ) + logger.info( + "auto_pin_repinned_for_image thread_id=%s search_space_id=%s " + "previous_config_id=%s", + thread_id, + search_space_id, + pinned_id, + ) logger.info( "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s", thread_id, diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 69feb30eb..a151a0d6e 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -30,11 +30,7 @@ from litellm.exceptions import ( ) from pydantic import Field -from app.services.model_resolver import ( - NATIVE_PROVIDER_PREFIX, - native_connection_from_config, - to_litellm, -) +from app.services.model_resolver import native_connection_from_config, to_litellm from app.utils.perf import get_perf_logger litellm.json_logs = False @@ -101,10 +97,6 @@ def _sanitize_content(content: Any) -> Any: # Special ID for Auto mode - uses router for load balancing AUTO_MODE_ID = 0 -# Historical export kept for callers that still import PROVIDER_MAP. -PROVIDER_MAP = NATIVE_PROVIDER_PREFIX - - class LLMRouterService: """ Singleton service for managing LiteLLM Router. diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 75451d01f..86a9c8556 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -10,13 +10,11 @@ from sqlalchemy.orm import selectinload from app.config import config from app.db import Model, SearchSpace -from app.services.llm_router_service import ( - AUTO_MODE_ID, - ChatLiteLLMRouter, - LLMRouterService, - get_auto_mode_llm, - is_auto_mode, +from app.services.auto_model_pin_service import ( + auto_model_candidates, + choose_auto_model_candidate, ) +from app.services.llm_router_service import AUTO_MODE_ID, ChatLiteLLMRouter, is_auto_mode from app.services.model_resolver import native_connection_from_config, to_litellm from app.services.token_tracking_service import token_tracker @@ -78,7 +76,7 @@ def _legacy_config_connection( api_version: str | None = None, ) -> tuple[str, dict]: cfg = { - "provider": provider, + "litellm_provider": provider.lower(), "model_name": model_name, "api_key": api_key, "api_base": api_base, @@ -325,23 +323,21 @@ async def get_search_space_llm_instance( logger.error(f"No {role} LLM configured for search space {search_space_id}") return None - # Check for Auto mode (ID 0) - use router for load balancing + # Auto mode resolves to one concrete global or BYOK model from the + # unified model-connections catalog. if is_auto_mode(llm_config_id): - if not LLMRouterService.is_initialized(): - logger.error( - "Auto mode requested but LLM Router not initialized. " - "Ensure global_llm_config.yaml exists with valid configs." - ) - return None - - try: - logger.debug( - f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}" - ) - return get_auto_mode_llm(streaming=not disable_streaming) - except Exception as e: - logger.error(f"Failed to create ChatLiteLLMRouter: {e}") + candidates = await auto_model_candidates( + session, + search_space_id=search_space_id, + user_id=search_space.user_id, + capability="chat", + ) + if not candidates: + logger.error("No chat-capable models available for Auto mode") return None + llm_config_id = int( + choose_auto_model_candidate(candidates, search_space_id)["id"] + ) # Check if this is a global virtual model (negative ID) if llm_config_id < 0: @@ -414,7 +410,7 @@ async def get_vision_llm( """Get the search space's vision LLM instance for screenshot analysis. Resolves from the new connection/model role bindings: - - Auto mode (ID 0): VisionLLMRouterService + - Auto mode (ID 0): unified global/BYOK model candidate selection - Global (negative ID): virtual GLOBAL models from YAML - DB (positive ID): Model + Connection tables @@ -424,10 +420,7 @@ async def get_vision_llm( unwrapped — they don't consume premium credit (issue M). """ from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM - from app.services.vision_llm_router_service import ( - VisionLLMRouterService, - is_vision_auto_mode, - ) + from app.services.vision_llm_router_service import is_vision_auto_mode try: result = await session.execute( @@ -476,26 +469,16 @@ async def get_vision_llm( return None if is_vision_auto_mode(config_id): - if not VisionLLMRouterService.is_initialized(): - logger.error( - "Vision Auto mode requested but Vision LLM Router not initialized" - ) - return None - try: - # Auto mode is currently treated as free at the wrapper - # level — the underlying router can dispatch to either - # premium or free YAML configs but routing decisions are - # opaque. If/when we want to bill Auto-routed vision - # calls we'd need to thread the resolved deployment's - # billing_tier back from the router. For now we keep - # parity with chat Auto, which also doesn't pre-classify. - return ChatLiteLLMRouter( - router=VisionLLMRouterService.get_router(), - streaming=True, - ) - except Exception as e: - logger.error(f"Failed to create vision ChatLiteLLMRouter: {e}") + candidates = await auto_model_candidates( + session, + search_space_id=search_space_id, + user_id=owner_user_id, + capability="vision", + ) + if not candidates: + logger.error("No vision-capable models available for Auto mode") return None + config_id = int(choose_auto_model_candidate(candidates, search_space_id)["id"]) if config_id < 0: global_model = get_global_model(config_id) diff --git a/surfsense_backend/app/services/model_list_service.py b/surfsense_backend/app/services/model_list_service.py index 33837a8a0..1ef0b0c90 100644 --- a/surfsense_backend/app/services/model_list_service.py +++ b/surfsense_backend/app/services/model_list_service.py @@ -154,19 +154,19 @@ def _process_models(raw_models: list[dict]) -> list[dict]: } ) - # 2) Emit for the native provider when we have a mapping - native_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug) - if native_provider: + # 2) Emit for the direct provider when we have a mapping + direct_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug) + if direct_provider: # Google's Gemini API only serves gemini-* models. # Open-source models like gemma-* are NOT available through it. - if native_provider == "GOOGLE" and not model_name.startswith("gemini-"): + if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"): continue processed.append( { "value": model_name, "label": name, - "provider": native_provider, + "provider": direct_provider, "context_window": context_window, } ) diff --git a/surfsense_backend/app/services/model_resolver.py b/surfsense_backend/app/services/model_resolver.py index ec485a5ae..ffa77a9a2 100644 --- a/surfsense_backend/app/services/model_resolver.py +++ b/surfsense_backend/app/services/model_resolver.py @@ -9,53 +9,12 @@ from __future__ import annotations from collections.abc import Mapping from typing import TYPE_CHECKING, Any -from app.services.provider_api_base import resolve_api_base - if TYPE_CHECKING: from app.db import Connection PROTOCOL_OLLAMA = "OLLAMA" PROTOCOL_OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" -PROTOCOL_NATIVE = "NATIVE" - -NATIVE_PROVIDER_PREFIX: dict[str, str] = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "AZURE": "azure", - "OPENROUTER": "openrouter", - "COMETAPI": "cometapi", - "XAI": "xai", - "BEDROCK": "bedrock", - "AWS_BEDROCK": "bedrock", - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", - "GITHUB_MODELS": "github", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - "HUGGINGFACE": "huggingface", - "MINIMAX": "openai", - "RECRAFT": "recraft", - "XINFERENCE": "xinference", - "NSCALE": "nscale", - "CUSTOM": "custom", -} +PROTOCOL_ANTHROPIC = "ANTHROPIC" def ensure_v1(base_url: str | None) -> str | None: @@ -77,6 +36,23 @@ def _protocol_value(protocol: Any) -> str: return getattr(protocol, "value", str(protocol)) +def default_litellm_provider(protocol: Any) -> str: + protocol_value = _protocol_value(protocol) + defaults = { + PROTOCOL_OLLAMA: "ollama_chat", + PROTOCOL_ANTHROPIC: "anthropic", + PROTOCOL_OPENAI_COMPATIBLE: "openai", + } + return defaults.get(protocol_value, "openai") + + +def _execution_api_base(protocol: str, base_url: str | None) -> str | None: + del protocol + if not base_url: + return None + return base_url.rstrip("/") + + def to_litellm( conn: Connection | Mapping[str, Any], model_id: str, @@ -85,38 +61,19 @@ def to_litellm( protocol = _protocol_value(_conn_value(conn, "protocol")) base_url = _conn_value(conn, "base_url") api_key = _conn_value(conn, "api_key") - native_provider = _conn_value(conn, "native_provider") + litellm_provider = ( + _conn_value(conn, "litellm_provider") or default_litellm_provider(protocol) + ) extra = _conn_value(conn, "extra") or {} kwargs: dict[str, Any] = {} if api_key: kwargs["api_key"] = api_key - if protocol == PROTOCOL_OLLAMA: - model_string = f"ollama_chat/{model_id}" - if base_url: - kwargs["api_base"] = base_url.rstrip("/") - elif protocol == PROTOCOL_OPENAI_COMPATIBLE: - model_string = f"openai/{model_id}" - api_base = ensure_v1(base_url) - if api_base: - kwargs["api_base"] = api_base - else: - provider_key = (native_provider or "").upper() - prefix = NATIVE_PROVIDER_PREFIX.get(provider_key, provider_key.lower()) - if prefix == "custom": - custom_provider = extra.get("custom_provider") or native_provider - model_string = f"{custom_provider}/{model_id}" if custom_provider else model_id - else: - model_string = f"{prefix}/{model_id}" - - api_base = resolve_api_base( - provider=provider_key, - provider_prefix=prefix, - config_api_base=base_url, - ) - if api_base: - kwargs["api_base"] = api_base + model_string = f"{litellm_provider}/{model_id}" if litellm_provider else model_id + api_base = _execution_api_base(protocol, base_url) + if api_base: + kwargs["api_base"] = api_base if api_version := extra.get("api_version"): kwargs["api_version"] = api_version @@ -126,18 +83,21 @@ def to_litellm( def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: - """Build an in-memory NATIVE connection mapping from a legacy/global config.""" - provider = str(config.get("provider") or config.get("custom_provider") or "CUSTOM") + """Build an in-memory connection mapping from a global config.""" + protocol = str(config.get("protocol") or PROTOCOL_OPENAI_COMPATIBLE) + litellm_provider = str( + config.get("litellm_provider") + or config.get("custom_provider") + or default_litellm_provider(protocol) + ) extra: dict[str, Any] = { "litellm_params": config.get("litellm_params") or {}, } if config.get("api_version"): extra["api_version"] = config.get("api_version") - if config.get("custom_provider"): - extra["custom_provider"] = config.get("custom_provider") return { - "protocol": PROTOCOL_NATIVE, - "native_provider": provider, + "protocol": protocol, + "litellm_provider": litellm_provider, "base_url": config.get("api_base") or None, "api_key": config.get("api_key") or None, "extra": extra, @@ -145,7 +105,7 @@ def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: __all__ = [ - "NATIVE_PROVIDER_PREFIX", + "default_litellm_provider", "ensure_v1", "native_connection_from_config", "to_litellm", diff --git a/surfsense_backend/app/services/vision_llm_router_service.py b/surfsense_backend/app/services/vision_llm_router_service.py index 0c7182ecf..0ff716324 100644 --- a/surfsense_backend/app/services/vision_llm_router_service.py +++ b/surfsense_backend/app/services/vision_llm_router_service.py @@ -3,19 +3,12 @@ from typing import Any from litellm import Router -from app.services.model_resolver import ( - NATIVE_PROVIDER_PREFIX, - native_connection_from_config, - to_litellm, -) +from app.services.model_resolver import native_connection_from_config, to_litellm logger = logging.getLogger(__name__) VISION_AUTO_MODE_ID = 0 -VISION_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX - - class VisionLLMRouterService: _instance = None _router: Router | None = None @@ -141,12 +134,11 @@ def is_vision_auto_mode(config_id: int | None) -> bool: def build_vision_model_string( - provider: str, model_name: str, custom_provider: str | None + litellm_provider: str, model_name: str, custom_provider: str | None ) -> str: if custom_provider: return f"{custom_provider}/{model_name}" - prefix = VISION_PROVIDER_MAP.get(provider.upper(), provider.lower()) - return f"{prefix}/{model_name}" + return f"{litellm_provider}/{model_name}" def get_global_vision_llm_config(config_id: int) -> dict | None: diff --git a/surfsense_backend/app/services/vision_model_list_service.py b/surfsense_backend/app/services/vision_model_list_service.py index fc459910b..6eae8c455 100644 --- a/surfsense_backend/app/services/vision_model_list_service.py +++ b/surfsense_backend/app/services/vision_model_list_service.py @@ -97,16 +97,16 @@ def _process_vision_models(raw_models: list[dict]) -> list[dict]: } ) - native_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug) - if native_provider: - if native_provider == "GOOGLE" and not model_name.startswith("gemini-"): + direct_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug) + if direct_provider: + if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"): continue processed.append( { "value": model_name, "label": name, - "provider": native_provider, + "provider": direct_provider, "context_window": context_window, } ) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py index 69b9f4ab8..f6fcf75d7 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py @@ -40,7 +40,7 @@ def check_image_input_capability( else None ) if not is_known_text_only_chat_model( - provider=agent_config.provider, + litellm_provider=agent_config.provider, model_name=agent_config.model_name, base_model=agent_base_model, custom_provider=agent_config.custom_provider, diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py index fe3d210bb..d5e8c3729 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py @@ -80,7 +80,6 @@ async def _generate_title( from litellm import acompletion from app.services.llm_router_service import LLMRouterService - from app.services.provider_api_base import resolve_api_base from app.services.token_tracking_service import _turn_accumulator # Excludes this turn's own assistant row (pre-written by @@ -125,26 +124,12 @@ async def _generate_title( router = LLMRouterService.get_router() response = await router.acompletion(model="auto", messages=messages) else: - # Apply the same ``api_base`` cascade chat / vision / image-gen - # call sites use so we never inherit ``litellm.api_base`` - # (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat - # config itself ships an empty ``api_base``. Without this the - # title-gen on an OpenRouter chat config would 404 against the - # inherited Azure endpoint — see ``provider_api_base`` for the - # same bug repro on the image-gen / vision paths. raw_model = getattr(llm, "model", "") or "" - provider_prefix = raw_model.split("/", 1)[0] if "/" in raw_model else None - provider_value = agent_config.provider if agent_config is not None else None - title_api_base = resolve_api_base( - provider=provider_value, - provider_prefix=provider_prefix, - config_api_base=getattr(llm, "api_base", None), - ) response = await acompletion( model=raw_model, messages=messages, api_key=getattr(llm, "api_key", None), - api_base=title_api_base, + api_base=getattr(llm, "api_base", None), ) usage_info = None diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py index 7e2bc950b..f6870f5fa 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py @@ -1,8 +1,8 @@ """Load an LLM + AgentConfig bundle for a given config id. Handles both code paths uniformly: -- ``config_id >= 0`` → database-backed ``NewLLMConfig`` row (per-user/per-space). -- ``config_id < 0`` → YAML-defined global LLM config (built-in defaults). +- ``config_id > 0`` → database-backed model-connection ``Model`` row. +- ``config_id < 0`` → virtual global model materialized from YAML/OpenRouter. Returns ``(llm, agent_config, error_message)``; on success ``error_message`` is ``None``. The caller emits the friendly SSE error frame. @@ -12,15 +12,72 @@ from __future__ import annotations from typing import Any +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from app.agents.chat.runtime.llm_config import ( AgentConfig, - create_chat_litellm_from_agent_config, - create_chat_litellm_from_config, - load_agent_config, - load_global_llm_config_by_id, + SanitizedChatLiteLLM, ) +from app.config import config +from app.db import Model, SearchSpace +from app.services.model_resolver import to_litellm + + +def _agent_config_from_resolved( + *, + config_id: int, + config_name: str | None, + provider: str, + model_name: str, + api_key: str | None, + api_base: str | None, + litellm_params: dict | None, + supports_image_input: bool, + billing_tier: str = "free", +) -> AgentConfig: + return AgentConfig( + provider=provider, + model_name=model_name, + api_key=api_key or "", + api_base=api_base, + custom_provider=None, + litellm_params=litellm_params, + config_id=config_id, + config_name=config_name, + is_auto_mode=False, + billing_tier=billing_tier, + is_premium=billing_tier == "premium", + supports_image_input=supports_image_input, + ) + + +async def _load_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace | None: + result = await session.execute(select(SearchSpace).where(SearchSpace.id == search_space_id)) + return result.scalars().first() + + +async def _load_db_model( + session: AsyncSession, + *, + model_id: int, + search_space: SearchSpace, +) -> Model | None: + result = await session.execute( + select(Model) + .options(selectinload(Model.connection)) + .where(Model.id == model_id, Model.enabled.is_(True)) + ) + model = result.scalars().first() + if not model or not model.connection or not model.connection.enabled: + return None + conn = model.connection + if conn.search_space_id is not None and conn.search_space_id != search_space.id: + return None + if conn.user_id is not None and conn.user_id != search_space.user_id: + return None + return model async def load_llm_bundle( @@ -29,29 +86,67 @@ async def load_llm_bundle( config_id: int, search_space_id: int, ) -> tuple[Any, AgentConfig | None, str | None]: - if config_id >= 0: - loaded_agent_config = await load_agent_config( - session=session, - config_id=config_id, - search_space_id=search_space_id, + search_space = await _load_search_space(session, search_space_id) + if not search_space: + return None, None, f"Search space {search_space_id} not found" + + if config_id > 0: + model = await _load_db_model( + session, + model_id=config_id, + search_space=search_space, ) - if not loaded_agent_config: + if not model or not (model.capabilities or {}).get("chat"): return ( None, None, - f"Failed to load NewLLMConfig with id {config_id}", + f"Failed to load chat model with id {config_id}", ) + model_string, litellm_kwargs = to_litellm(model.connection, model.model_id) + agent_config = _agent_config_from_resolved( + config_id=config_id, + config_name=model.display_name or model.model_id, + provider=model.connection.litellm_provider or "", + model_name=model.model_id, + api_key=model.connection.api_key, + api_base=model.connection.base_url, + litellm_params=(model.connection.extra or {}).get("litellm_params"), + supports_image_input=bool((model.capabilities or {}).get("vision")), + billing_tier="free", + ) return ( - create_chat_litellm_from_agent_config(loaded_agent_config), - loaded_agent_config, + SanitizedChatLiteLLM(model=model_string, **litellm_kwargs), + agent_config, None, ) - loaded_llm_config = load_global_llm_config_by_id(config_id) - if not loaded_llm_config: - return None, None, f"Failed to load LLM config with id {config_id}" - return ( - create_chat_litellm_from_config(loaded_llm_config), - AgentConfig.from_yaml_config(loaded_llm_config), + global_model = next((m for m in config.GLOBAL_MODELS if m.get("id") == config_id), None) + if not global_model or not (global_model.get("capabilities") or {}).get("chat"): + return None, None, f"Failed to load global chat model with id {config_id}" + global_connection = next( + ( + c + for c in config.GLOBAL_CONNECTIONS + if c.get("id") == global_model.get("connection_id") + ), + None, + ) + if not global_connection: + return None, None, f"Failed to load global connection for model {config_id}" + model_string, litellm_kwargs = to_litellm(global_connection, global_model["model_id"]) + agent_config = _agent_config_from_resolved( + config_id=config_id, + config_name=global_model.get("display_name") or global_model.get("model_id"), + provider=global_connection.get("litellm_provider") or "", + model_name=global_model["model_id"], + api_key=global_connection.get("api_key"), + api_base=global_connection.get("base_url"), + litellm_params=(global_connection.get("extra") or {}).get("litellm_params"), + supports_image_input=bool((global_model.get("capabilities") or {}).get("vision")), + billing_tier=str(global_model.get("billing_tier", "free")).lower(), + ) + return ( + SanitizedChatLiteLLM(model=model_string, **litellm_kwargs), + agent_config, None, ) diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index d1af29aeb..0af41a7ee 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -75,10 +75,10 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): config, "GLOBAL_LLM_CONFIGS", [ - {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"}, { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -117,7 +117,7 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): [ { "id": -2, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", @@ -125,7 +125,7 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): }, { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -164,7 +164,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): [ { "id": -1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5.1", "api_key": "k1", "billing_tier": "premium", @@ -173,7 +173,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): }, { "id": -2, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5.4", "api_key": "k2", "billing_tier": "premium", @@ -182,7 +182,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): }, { "id": -3, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "openai/gpt-5.4", "api_key": "k3", "billing_tier": "premium", @@ -222,7 +222,7 @@ async def test_next_turn_reuses_existing_pin(monkeypatch): [ { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -263,7 +263,7 @@ async def test_premium_eligible_auto_can_pin_premium(monkeypatch): [ { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -301,14 +301,14 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch): [ { "id": -2, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", }, { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -346,14 +346,14 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): [ { "id": -2, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", }, { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -391,14 +391,14 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): [ { "id": -2, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", }, { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -437,7 +437,7 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch): config, "GLOBAL_LLM_CONFIGS", [ - {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"}, ], ) @@ -462,7 +462,7 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): config, "GLOBAL_LLM_CONFIGS", [ - {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"}, ], ) @@ -504,7 +504,7 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch): [ { "id": -1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "venice/dead-model", "api_key": "k1", "billing_tier": "free", @@ -514,7 +514,7 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch): }, { "id": -2, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemini-flash", "api_key": "k1", "billing_tier": "free", @@ -556,7 +556,7 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch): [ { "id": -1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "api_key": "k-yaml", "billing_tier": "premium", @@ -566,7 +566,7 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch): }, { "id": -2, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "openai/gpt-5", "api_key": "k-or", "billing_tier": "premium", @@ -608,7 +608,7 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch [ { "id": -1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "api_key": "k-yaml", "billing_tier": "premium", @@ -618,7 +618,7 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch }, { "id": -2, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemini-flash:free", "api_key": "k-or", "billing_tier": "free", @@ -656,7 +656,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch): high_score_cfgs = [ { "id": -i, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": f"gpt-x-{i}", "api_key": "k", "billing_tier": "premium", @@ -668,7 +668,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch): ] low_score_trap = { "id": -99, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "tiny-legacy", "api_key": "k", "billing_tier": "premium", @@ -729,7 +729,7 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch): [ { "id": -1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "venice/dead-model", "api_key": "k", "billing_tier": "premium", @@ -739,7 +739,7 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch): }, { "id": -2, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "api_key": "k", "billing_tier": "premium", @@ -781,7 +781,7 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): [ { "id": -1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "api_key": "k", "billing_tier": "premium", @@ -791,7 +791,7 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): }, { "id": -2, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5-pro", "api_key": "k", "billing_tier": "premium", @@ -839,7 +839,7 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): [ { "id": -1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemma-4-26b-a4b-it:free", "api_key": "k", "billing_tier": "free", @@ -849,7 +849,7 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): }, { "id": -2, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemini-2.5-flash:free", "api_key": "k", "billing_tier": "free", @@ -892,7 +892,7 @@ async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch): [ { "id": -1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemma-4-26b-a4b-it:free", "api_key": "k", "billing_tier": "free", @@ -937,7 +937,7 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa [ { "id": -1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemma-4-26b-a4b-it:free", "api_key": "k", "billing_tier": "free", @@ -947,7 +947,7 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa }, { "id": -2, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemini-2.5-flash:free", "api_key": "k", "billing_tier": "free", diff --git a/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py index 0e19b80e4..e267d59ba 100644 --- a/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py +++ b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py @@ -74,7 +74,7 @@ def _thread(*, pinned: int | None = None): def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict: return { "id": id_, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": f"vision-{id_}", "api_key": "k", "billing_tier": tier, @@ -87,7 +87,7 @@ def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict: def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict: return { "id": id_, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": f"text-{id_}", "api_key": "k", "billing_tier": tier, @@ -261,7 +261,7 @@ async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch): session = _FakeSession(_thread()) cfg_unannotated_vision = { "id": -2, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-4o", # known vision model in LiteLLM map "api_key": "k", "billing_tier": "free", diff --git a/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py index c309ff881..efe906ac0 100644 --- a/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py +++ b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py @@ -25,10 +25,10 @@ def _fake_yaml_config( return { "id": id, "name": f"yaml-{id}", - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": model_name, "api_key": "sk-test", - "api_base": "", + "api_base": "https://api.openai.com/v1", "billing_tier": billing_tier, "rpm": 100, "tpm": 100_000, @@ -54,10 +54,10 @@ def _fake_openrouter_config( return { "id": id, "name": f"or-{id}", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": model_name, "api_key": "sk-or-test", - "api_base": "", + "api_base": "https://openrouter.ai/api/v1", "billing_tier": billing_tier, "rpm": 20 if billing_tier == "free" else 200, "tpm": 100_000 if billing_tier == "free" else 1_000_000, diff --git a/surfsense_backend/tests/unit/services/test_or_health_enrichment.py b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py index 1c74aa928..b4b6618a4 100644 --- a/surfsense_backend/tests/unit/services/test_or_health_enrichment.py +++ b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py @@ -25,7 +25,7 @@ def _or_cfg( ) -> dict: return { "id": cid, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": model_name, "billing_tier": tier, "auto_pin_tier": "B" if tier == "premium" else "C", @@ -144,7 +144,7 @@ async def test_enrich_health_only_touches_or_provider(monkeypatch): """YAML cfgs that aren't OPENROUTER must be skipped entirely.""" yaml_cfg = { "id": -1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "billing_tier": "premium", "auto_pin_tier": "A", @@ -313,7 +313,7 @@ async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch): """When the catalogue has no OR cfgs at all, no HTTP calls fire.""" yaml_cfg: dict[str, Any] = { "id": -1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "billing_tier": "premium", } diff --git a/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py index 792d059b0..6bfc72bf3 100644 --- a/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py +++ b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py @@ -35,7 +35,7 @@ def test_safety_net_does_not_fire_for_azure_gpt_4o(): it text-only.""" assert ( is_known_text_only_chat_model( - provider="AZURE_OPENAI", + litellm_provider="azure", model_name="my-azure-deployment", base_model="gpt-4o", ) @@ -49,7 +49,7 @@ def test_safety_net_does_not_fire_for_unknown_model(): LiteLLM doesn't know about must flow through to the provider.""" assert ( is_known_text_only_chat_model( - provider="CUSTOM", + litellm_provider="custom", custom_provider="brand_new_proxy", model_name="brand-new-model-x9", ) @@ -69,7 +69,7 @@ def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch): assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="gpt-4o", ) is False @@ -88,7 +88,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch): monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false) assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="text-only-stub", ) is True @@ -100,7 +100,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch): monkeypatch.setattr(pc.litellm, "get_model_info", _info_true) assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="vision-stub", ) is False @@ -112,7 +112,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch): monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing) assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="missing-key-stub", ) is False From 831ad23c6c8f72d4f731290240cb38c7d44dda31 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 18:22:45 +0530 Subject: [PATCH 020/212] fix(chat): harden image generation model routing --- .../deliverables/tools/generate_image.py | 90 ++++++++++----- .../app/routes/image_generation_routes.py | 107 ++++++++++-------- .../app/services/image_gen_router_service.py | 12 +- .../scripts/verify_chat_image_capability.py | 35 ++---- .../tests/unit/routes/test_image_gen_quota.py | 6 +- .../test_image_gen_api_base_defense.py | 51 +++------ .../test_vision_llm_api_base_defense.py | 26 ++--- 7 files changed, 156 insertions(+), 171 deletions(-) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py index dd980c51c..fda327750 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py @@ -23,20 +23,26 @@ from app.db import ( ) from app.services.image_gen_router_service import ( IMAGE_GEN_AUTO_MODE_ID, - ImageGenRouterService, is_image_gen_auto_mode, ) -from app.services.model_resolver import native_connection_from_config, to_litellm +from app.services.auto_model_pin_service import ( + auto_model_candidates, + choose_auto_model_candidate, +) +from app.services.model_resolver import to_litellm from app.utils.signed_image_urls import generate_image_token logger = logging.getLogger(__name__) -def _get_global_image_gen_config(config_id: int) -> dict | None: - """Get a global image gen config by negative ID.""" - for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: - if cfg.get("id") == config_id: - return cfg - return None +def _get_global_model(model_id: int) -> dict | None: + return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None) + + +def _get_global_connection(connection_id: int) -> dict | None: + return next( + (c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id), + None, + ) def create_generate_image_tool( @@ -93,6 +99,16 @@ def create_generate_image_tool( # task's session is shared across every tool; without isolation, # autoflushes from a concurrent writer poison this tool too. async with shielded_async_session() as session: + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + if not search_space: + return _failed( + {"error": "Search space not found"}, + error="Search space not found", + ) + if image_gen_model_id_override is not None: # Automation run: use the captured image model, insulated from # later search-space changes. No search-space read needed. @@ -100,16 +116,6 @@ def create_generate_image_tool( image_gen_model_id_override or IMAGE_GEN_AUTO_MODE_ID ) else: - result = await session.execute( - select(SearchSpace).filter(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - if not search_space: - return _failed( - {"error": "Search space not found"}, - error="Search space not found", - ) - config_id = ( search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID @@ -122,24 +128,39 @@ def create_generate_image_tool( gen_kwargs["n"] = n if is_image_gen_auto_mode(config_id): - if not ImageGenRouterService.is_initialized(): + candidates = await auto_model_candidates( + session, + search_space_id=search_space_id, + user_id=search_space.user_id, + capability="image_gen", + ) + if not candidates: err = ( - "No image generation models configured. " + "No image generation models available. " "Please add an image model in Settings > Image Models." ) return _failed({"error": err}, error=err) - response = await ImageGenRouterService.aimage_generation( - prompt=prompt, model="auto", **gen_kwargs + config_id = int( + choose_auto_model_candidate(candidates, search_space_id)["id"] ) - elif config_id < 0: - cfg = _get_global_image_gen_config(config_id) - if not cfg: - err = f"Image generation config {config_id} not found" + + if config_id < 0: + global_model = _get_global_model(config_id) + if not global_model or not ( + global_model.get("capabilities") or {} + ).get("image_gen"): + err = f"Image generation model {config_id} not found" + return _failed({"error": err}, error=err) + global_connection = _get_global_connection( + global_model["connection_id"] + ) + if not global_connection: + err = f"Image generation connection for model {config_id} not found" return _failed({"error": err}, error=err) model_string, resolved_kwargs = to_litellm( - native_connection_from_config(cfg), - cfg["model_name"], + global_connection, + global_model["model_id"], ) gen_kwargs.update(resolved_kwargs) @@ -157,6 +178,19 @@ def create_generate_image_tool( if not db_model or not db_model.connection or not db_model.connection.enabled: err = f"Image generation model {config_id} not found" return _failed({"error": err}, error=err) + conn = db_model.connection + if ( + conn.search_space_id is not None + and conn.search_space_id != search_space_id + ): + err = f"Image generation model {config_id} not found" + return _failed({"error": err}, error=err) + if ( + conn.user_id is not None + and conn.user_id != search_space.user_id + ): + err = f"Image generation model {config_id} not found" + return _failed({"error": err}, error=err) if not (db_model.capabilities or {}).get("image_gen"): err = f"Model {config_id} is not image-generation capable" return _failed({"error": err}, error=err) diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 0de368d57..5be1cedf1 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -45,10 +45,13 @@ from app.services.billable_calls import ( ) from app.services.image_gen_router_service import ( IMAGE_GEN_AUTO_MODE_ID, - ImageGenRouterService, is_image_gen_auto_mode, ) -from app.services.model_resolver import native_connection_from_config, to_litellm +from app.services.auto_model_pin_service import ( + auto_model_candidates, + choose_auto_model_candidate, +) +from app.services.model_resolver import to_litellm from app.users import current_active_user from app.utils.rbac import check_permission from app.utils.signed_image_urls import verify_image_token @@ -56,22 +59,15 @@ from app.utils.signed_image_urls import verify_image_token router = APIRouter() logger = logging.getLogger(__name__) -def _get_global_image_gen_config(config_id: int) -> dict | None: - """Get a global image generation configuration by ID (negative IDs).""" - if config_id == IMAGE_GEN_AUTO_MODE_ID: - return { - "id": IMAGE_GEN_AUTO_MODE_ID, - "name": "Auto (Fastest)", - "provider": "AUTO", - "model_name": "auto", - "is_auto_mode": True, - } - if config_id > 0: - return None - for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: - if cfg.get("id") == config_id: - return cfg - return None +def _get_global_model(model_id: int) -> dict | None: + return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None) + + +def _get_global_connection(connection_id: int) -> dict | None: + return next( + (c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id), + None, + ) async def _resolve_billing_for_image_gen( @@ -87,30 +83,41 @@ async def _resolve_billing_for_image_gen( config that will actually run, and so we don't open an ``ImageGeneration`` row for a request that's about to 402. - User-owned (positive ID) BYOK configs are always free — they cost - the user nothing on our side. Auto mode currently treats as free - because the underlying router can dispatch to either premium or - free YAML configs and we don't surface the resolved deployment up - here yet. Bringing Auto under premium billing would require - threading the chosen deployment back from ``ImageGenRouterService``. + User-owned (positive ID) BYOK models are always free — they cost + the user nothing on our side. Auto mode resolves to one concrete + global or BYOK model before billing is calculated. """ resolved_id = config_id if resolved_id is None: resolved_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID if is_image_gen_auto_mode(resolved_id): - return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS) + candidates = await auto_model_candidates( + session, + search_space_id=search_space.id, + user_id=search_space.user_id, + capability="image_gen", + ) + if not candidates: + return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS) + selected = choose_auto_model_candidate(candidates, search_space.id) + resolved_id = int(selected["id"]) if resolved_id < 0: - cfg = _get_global_image_gen_config(resolved_id) or {} - billing_tier = str(cfg.get("billing_tier", "free")).lower() - base_model, _ = to_litellm(native_connection_from_config(cfg), cfg.get("model_name", "")) + global_model = _get_global_model(resolved_id) or {} + global_connection = _get_global_connection(global_model.get("connection_id", 0)) + billing_tier = str(global_model.get("billing_tier", "free")).lower() + if global_connection and global_model.get("model_id"): + base_model, _ = to_litellm(global_connection, global_model["model_id"]) + else: + base_model = "global_image_model" + catalog = global_model.get("catalog") or {} reserve_micros = int( - cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS + catalog.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS ) return (billing_tier, base_model, reserve_micros) - # Positive ID = user-owned BYOK image-gen config — always free. + # Positive ID = user-owned BYOK image-gen model — always free. return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS) @@ -146,23 +153,28 @@ async def _execute_image_generation( gen_kwargs["response_format"] = image_gen.response_format if is_image_gen_auto_mode(config_id): - if not ImageGenRouterService.is_initialized(): - raise ValueError( - "Auto mode requested but Image Generation Router not initialized. " - "Ensure global_llm_config.yaml has global_image_generation_configs." - ) - response = await ImageGenRouterService.aimage_generation( - prompt=image_gen.prompt, model="auto", **gen_kwargs + candidates = await auto_model_candidates( + session, + search_space_id=search_space.id, + user_id=search_space.user_id, + capability="image_gen", ) - elif config_id < 0: - # Global config from YAML - cfg = _get_global_image_gen_config(config_id) - if not cfg: - raise ValueError(f"Global image generation config {config_id} not found") + if not candidates: + raise ValueError("No image-generation models are available for Auto mode") + config_id = int(choose_auto_model_candidate(candidates, search_space.id)["id"]) + image_gen.image_generation_config_id = config_id + + if config_id < 0: + global_model = _get_global_model(config_id) + if not global_model or not (global_model.get("capabilities") or {}).get("image_gen"): + raise ValueError(f"Global image generation model {config_id} not found") + global_connection = _get_global_connection(global_model["connection_id"]) + if not global_connection: + raise ValueError(f"Global connection for image model {config_id} not found") model_string, resolved_kwargs = to_litellm( - native_connection_from_config(cfg), - cfg["model_name"], + global_connection, + global_model["model_id"], ) gen_kwargs.update(resolved_kwargs) @@ -183,6 +195,11 @@ async def _execute_image_generation( db_model = result.scalars().first() if not db_model or not db_model.connection or not db_model.connection.enabled: raise ValueError(f"Image generation model {config_id} not found") + conn = db_model.connection + if conn.search_space_id is not None and conn.search_space_id != search_space.id: + raise ValueError(f"Image generation model {config_id} not found") + if conn.user_id is not None and conn.user_id != search_space.user_id: + raise ValueError(f"Image generation model {config_id} not found") if not (db_model.capabilities or {}).get("image_gen"): raise ValueError(f"Model {config_id} is not image-generation capable") @@ -255,7 +272,7 @@ async def get_global_image_gen_configs( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("provider"), + "provider": cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/services/image_gen_router_service.py b/surfsense_backend/app/services/image_gen_router_service.py index 0b03f5c6d..3716537b9 100644 --- a/surfsense_backend/app/services/image_gen_router_service.py +++ b/surfsense_backend/app/services/image_gen_router_service.py @@ -20,23 +20,13 @@ from typing import Any from litellm import Router from litellm.utils import ImageResponse -from app.services.model_resolver import ( - NATIVE_PROVIDER_PREFIX, - native_connection_from_config, - to_litellm, -) +from app.services.model_resolver import native_connection_from_config, to_litellm logger = logging.getLogger(__name__) # Special ID for Auto mode - uses router for load balancing IMAGE_GEN_AUTO_MODE_ID = 0 -# Provider mapping for LiteLLM model string construction. -# Only includes providers that support image generation. -# See: https://docs.litellm.ai/docs/image_generation#supported-providers -IMAGE_GEN_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX - - class ImageGenRouterService: """ Singleton service for managing LiteLLM Router for image generation. diff --git a/surfsense_backend/scripts/verify_chat_image_capability.py b/surfsense_backend/scripts/verify_chat_image_capability.py index a49d4eab2..6e711f99a 100644 --- a/surfsense_backend/scripts/verify_chat_image_capability.py +++ b/surfsense_backend/scripts/verify_chat_image_capability.py @@ -55,7 +55,6 @@ from app.services.openrouter_integration_service import ( # noqa: E402 _OPENROUTER_DYNAMIC_MARKER, OpenRouterIntegrationService, ) -from app.services.provider_api_base import resolve_api_base # noqa: E402 from app.services.provider_capabilities import ( # noqa: E402 derive_supports_image_input, is_known_text_only_chat_model, @@ -154,13 +153,13 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]: litellm_params.get("base_model") if isinstance(litellm_params, dict) else None ) cap = derive_supports_image_input( - provider=cfg.get("provider"), + litellm_provider=cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), ) block = is_known_text_only_chat_model( - provider=cfg.get("provider"), + litellm_provider=cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), @@ -179,11 +178,7 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]: def _build_chat_model_string(cfg: dict) -> str: if cfg.get("custom_provider"): return f"{cfg['custom_provider']}/{cfg['model_name']}" - from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP - - prefix = _PROVIDER_PREFIX_MAP.get( - (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower() - ) + prefix = cfg.get("litellm_provider") or "openai" return f"{prefix}/{cfg['model_name']}" @@ -195,11 +190,6 @@ def _build_chat_model_string(cfg: dict) -> str: async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]: """Send a 1x1 PNG + `reply with one word: ok` to the chat config.""" model_string = _build_chat_model_string(cfg) - api_base = resolve_api_base( - provider=cfg.get("provider"), - provider_prefix=model_string.split("/", 1)[0], - config_api_base=cfg.get("api_base") or None, - ) kwargs: dict[str, Any] = { "model": model_string, "api_key": cfg.get("api_key"), @@ -218,8 +208,8 @@ async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]: "max_tokens": 16, "timeout": 60, } - if api_base: - kwargs["api_base"] = api_base + if cfg.get("api_base"): + kwargs["api_base"] = cfg["api_base"] if cfg.get("litellm_params"): # Strip pricing keys — they're tracking-only and confuse some # provider validators (e.g. azure/openai reject unknown kwargs @@ -257,20 +247,11 @@ _IMAGE_GEN_PROMPTS: tuple[str, ...] = ( async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]: """Generate one tiny image to verify the deployment is reachable.""" - from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP - if cfg.get("custom_provider"): prefix = cfg["custom_provider"] else: - prefix = _PROVIDER_PREFIX_MAP.get( - (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower() - ) + prefix = cfg.get("litellm_provider") or "openai" model_string = f"{prefix}/{cfg['model_name']}" - api_base = resolve_api_base( - provider=cfg.get("provider"), - provider_prefix=prefix, - config_api_base=cfg.get("api_base") or None, - ) base_kwargs: dict[str, Any] = { "model": model_string, "api_key": cfg.get("api_key"), @@ -278,8 +259,8 @@ async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]: "size": "1024x1024", "timeout": 120, } - if api_base: - base_kwargs["api_base"] = api_base + if cfg.get("api_base"): + base_kwargs["api_base"] = cfg["api_base"] if cfg.get("api_version"): base_kwargs["api_version"] = cfg["api_version"] if cfg.get("litellm_params"): diff --git a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py index 636b7de31..53c0f50a9 100644 --- a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py +++ b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py @@ -49,14 +49,14 @@ async def test_resolve_billing_for_premium_global_config(monkeypatch): [ { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-image-1", "billing_tier": "premium", "quota_reserve_micros": 75_000, }, { "id": -2, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemini-2.5-flash-image", "billing_tier": "free", }, @@ -118,7 +118,7 @@ async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch): [ { "id": -7, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-image-1", "billing_tier": "premium", } diff --git a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py index 571e7d15b..63aa934a3 100644 --- a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py +++ b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py @@ -1,19 +1,4 @@ -"""Defense-in-depth: image-gen call sites must not let an empty -``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``. - -The bug repro: an OpenRouter image-gen config ships -``api_base=""``. The pre-fix call site in -``image_generation_routes._execute_image_generation`` did -``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which -silently dropped the empty string. LiteLLM then fell back to -``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``) -and OpenRouter's ``image_generation/transformation`` appended -``/chat/completions`` to it → 404 ``Resource not found``. - -This test pins the post-fix behaviour: with an empty ``api_base`` in -the config, the call site MUST set ``api_base`` to OpenRouter's public -URL instead of leaving it unset. -""" +"""Image-gen call sites must pass each config's explicit ``api_base``.""" from __future__ import annotations @@ -26,20 +11,17 @@ pytestmark = pytest.mark.unit @pytest.mark.asyncio -async def test_global_openrouter_image_gen_sets_api_base_when_config_empty(): - """The global-config branch (``config_id < 0``) of - ``_execute_image_generation`` must apply the resolver and pin - ``api_base`` to OpenRouter when the config ships an empty string. - """ +async def test_global_openrouter_image_gen_sets_explicit_api_base(): + """The global-config branch forwards the explicit OpenRouter base.""" from app.routes import image_generation_routes cfg = { "id": -20_001, "name": "GPT Image 1 (OpenRouter)", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "openai/gpt-image-1", "api_key": "sk-or-test", - "api_base": "", # the original bug shape + "api_base": "https://openrouter.ai/api/v1", "api_version": None, "litellm_params": {}, } @@ -80,16 +62,13 @@ async def test_global_openrouter_image_gen_sets_api_base_when_config_empty(): session=session, image_gen=image_gen, search_space=search_space ) - # The whole point of the fix: even with empty ``api_base`` in the - # config, we forward OpenRouter's public URL so the call doesn't - # inherit an Azure endpoint. assert captured.get("api_base") == "https://openrouter.ai/api/v1" assert captured["model"] == "openrouter/openai/gpt-image-1" @pytest.mark.asyncio -async def test_generate_image_tool_global_sets_api_base_when_config_empty(): - """Same defense at the agent tool entry point — both surfaces share +async def test_generate_image_tool_global_sets_explicit_api_base(): + """Same explicit-base behavior at the agent tool entry point — both surfaces share the same OpenRouter config payloads.""" from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools import ( generate_image as gi_module, @@ -98,10 +77,10 @@ async def test_generate_image_tool_global_sets_api_base_when_config_empty(): cfg = { "id": -20_001, "name": "GPT Image 1 (OpenRouter)", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "openai/gpt-image-1", "api_key": "sk-or-test", - "api_base": "", + "api_base": "https://openrouter.ai/api/v1", "api_version": None, "litellm_params": {}, } @@ -171,20 +150,16 @@ async def test_generate_image_tool_global_sets_api_base_when_config_empty(): assert captured["model"] == "openrouter/openai/gpt-image-1" -def test_image_gen_router_deployment_sets_api_base_when_config_empty(): - """The Auto-mode router pool must also resolve ``api_base`` when an - OpenRouter config ships an empty string. The deployment dict is fed - straight to ``litellm.Router``, so a missing ``api_base`` would - leak the same way as the direct call sites. - """ +def test_image_gen_router_deployment_sets_explicit_api_base(): + """The Auto-mode router pool carries explicit api_base into deployments.""" from app.services.image_gen_router_service import ImageGenRouterService deployment = ImageGenRouterService._config_to_deployment( { "model_name": "openai/gpt-image-1", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "api_key": "sk-or-test", - "api_base": "", + "api_base": "https://openrouter.ai/api/v1", } ) assert deployment is not None diff --git a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py index 5e3aa6eda..48dfc8e0b 100644 --- a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py +++ b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py @@ -1,12 +1,4 @@ -"""Defense-in-depth: vision-LLM resolution must not leak ``api_base`` -defaults from ``litellm.api_base`` either. - -Vision shares the same shape as image-gen — global YAML / OpenRouter -dynamic configs ship ``api_base=""`` and the pre-fix ``get_vision_llm`` -call sites would silently drop the empty string and inherit -``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on -construction so we test the kwargs we hand to it instead. -""" +"""Vision LLM resolution must pass explicit per-config ``api_base``.""" from __future__ import annotations @@ -19,19 +11,16 @@ pytestmark = pytest.mark.unit @pytest.mark.asyncio async def test_get_vision_llm_global_openrouter_sets_api_base(): - """Global negative-ID branch: an OpenRouter vision config with - ``api_base=""`` must end up calling ``SanitizedChatLiteLLM`` with - ``api_base="https://openrouter.ai/api/v1"`` — never an empty string, - never silently absent.""" + """Global negative-ID branch forwards the explicit OpenRouter base.""" from app.services import llm_service cfg = { "id": -30_001, "name": "GPT-4o Vision (OpenRouter)", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "openai/gpt-4o", "api_key": "sk-or-test", - "api_base": "", + "api_base": "https://openrouter.ai/api/v1", "api_version": None, "litellm_params": {}, "billing_tier": "free", @@ -72,16 +61,15 @@ async def test_get_vision_llm_global_openrouter_sets_api_base(): def test_vision_router_deployment_sets_api_base_when_config_empty(): - """Auto-mode vision router: deployments are fed to ``litellm.Router``, - so the resolver has to apply at deployment construction time too.""" + """Auto-mode vision router carries explicit api_base into deployments.""" from app.services.vision_llm_router_service import VisionLLMRouterService deployment = VisionLLMRouterService._config_to_deployment( { "model_name": "openai/gpt-4o", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "api_key": "sk-or-test", - "api_base": "", + "api_base": "https://openrouter.ai/api/v1", } ) assert deployment is not None From 3089dd4cb6b9ba1bc8132ef50967bd70b6a3f33b Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 18:22:57 +0530 Subject: [PATCH 021/212] refactor(model-connections): simplify connection settings UI --- .../model-connections-mutation.atoms.ts | 12 +- .../components/new-chat/model-selector.tsx | 4 +- .../settings/model-connections-settings.tsx | 234 +++++++----------- .../types/model-connections.types.ts | 8 +- .../hooks/use-automation-eligible-models.ts | 2 +- .../lib/apis/model-connections-api.service.ts | 31 ++- 6 files changed, 126 insertions(+), 165 deletions(-) diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts index 76289e60d..101bad1b5 100644 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -4,8 +4,10 @@ import type { ConnectionCreateRequest, ConnectionUpdateRequest, ModelCreateRequest, + ModelRead, ModelRoles, ModelUpdateRequest, + VerifyConnectionResponse, } from "@/contracts/types/model-connections.types"; import { modelConnectionsApiService } from "@/lib/apis/model-connections-api.service"; import { cacheKeys } from "@/lib/query-client/cache-keys"; @@ -67,7 +69,7 @@ export const verifyModelConnectionMutationAtom = atomWithMutation((get) => { return { mutationKey: ["model-connections", "verify"], mutationFn: (id: number) => modelConnectionsApiService.verifyConnection(id), - onSuccess: (result) => { + onSuccess: (result: VerifyConnectionResponse) => { if (result.ok) { toast.success("Connection verified"); } else { @@ -90,11 +92,9 @@ export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => { return { mutationKey: ["model-connections", "discover"], mutationFn: (id: number) => modelConnectionsApiService.discoverModels(id), - onSuccess: (models) => { + onSuccess: (models: ModelRead[]) => { toast.success( - models.length - ? `${models.length} models discovered` - : "No models found for this connection" + models.length ? `${models.length} models discovered` : "No models found for this connection" ); invalidateModelConnections(searchSpaceId); }, @@ -132,7 +132,7 @@ export const testModelMutationAtom = atomWithMutation((get) => { return { mutationKey: ["models", "test"], mutationFn: (id: number) => modelConnectionsApiService.testModel(id), - onSuccess: (result) => { + onSuccess: (result: VerifyConnectionResponse) => { if (result.ok) toast.success("Model test succeeded"); else toast.error(result.message || "Model test failed"); invalidateModelConnections(searchSpaceId); diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 7c912afbb..4744da617 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -56,7 +56,7 @@ function modelName(model: ModelRead) { function connectionLabel(connection: ConnectionRead) { if (connection.scope === "GLOBAL") return "Hosted"; - return connection.native_provider || connection.protocol; + return connection.litellm_provider || connection.protocol; } function flattenChatModels(connections: ConnectionRead[]) { @@ -67,7 +67,7 @@ function flattenChatModels(connections: ConnectionRead[]) { ...model, connectionId: connection.id, connectionLabel: connectionLabel(connection), - provider: connection.native_provider || connection.protocol, + provider: connection.litellm_provider || connection.protocol, })) ); } diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 59873408f..29501abda 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -2,7 +2,7 @@ import { useAtom, useAtomValue } from "jotai"; import { CheckCircle2, PlugZap, Plus, RefreshCcw, XCircle } from "lucide-react"; -import { useMemo, useState } from "react"; +import { useState } from "react"; import { addManualModelMutationAtom, createModelConnectionMutationAtom, @@ -35,60 +35,32 @@ import type { ConnectionRead, ModelRead, } from "@/contracts/types/model-connections.types"; -import { isCloud } from "@/lib/env-config"; import { getProviderIcon } from "@/lib/provider-icons"; -type Preset = { - id: string; - label: string; - protocol: ConnectionProtocol; - nativeProvider?: string; - baseUrl?: string; - local?: boolean; -}; - -const PRESETS: Preset[] = [ - { id: "custom", label: "OpenAI-compatible (any URL)", protocol: "OPENAI_COMPATIBLE" }, - { id: "openai", label: "OpenAI", protocol: "NATIVE", nativeProvider: "OPENAI" }, - { id: "anthropic", label: "Anthropic", protocol: "NATIVE", nativeProvider: "ANTHROPIC" }, - { id: "openrouter", label: "OpenRouter", protocol: "NATIVE", nativeProvider: "OPENROUTER" }, +const PROTOCOL_OPTIONS: { value: ConnectionProtocol; label: string; description: string }[] = [ { - id: "ollama", + value: "OPENAI_COMPATIBLE", + label: "OpenAI-compatible", + description: "Use for OpenAI, OpenRouter, Groq, vLLM, LM Studio, and compatible APIs.", + }, + { + value: "ANTHROPIC", + label: "Anthropic", + description: "Use for Claude endpoints that require Anthropic headers.", + }, + { + value: "OLLAMA", label: "Ollama", - protocol: "OLLAMA", - baseUrl: "http://host.docker.internal:11434", - local: true, - }, - { - id: "lmstudio", - label: "LM Studio", - protocol: "OPENAI_COMPATIBLE", - baseUrl: "http://host.docker.internal:1234/v1", - local: true, - }, - { - id: "llamacpp", - label: "llama.cpp", - protocol: "OPENAI_COMPATIBLE", - baseUrl: "http://host.docker.internal:8080/v1", - local: true, - }, - { - id: "localai", - label: "LocalAI", - protocol: "OPENAI_COMPATIBLE", - baseUrl: "http://host.docker.internal:8080/v1", - local: true, - }, - { - id: "vllm", - label: "vLLM", - protocol: "OPENAI_COMPATIBLE", - baseUrl: "http://host.docker.internal:8000/v1", - local: true, + description: "Use for Ollama's native API.", }, ]; +function defaultLitellmProvider(protocol: ConnectionProtocol) { + if (protocol === "OLLAMA") return "ollama_chat"; + if (protocol === "ANTHROPIC") return "anthropic"; + return "openai"; +} + // Free-text URL hints (datalist), mirroring OpenWebUI. These never restrict // what the user can type — any OpenAI-compatible endpoint works. const URL_SUGGESTIONS = [ @@ -135,9 +107,9 @@ function flattenModels(connections: ConnectionRead[]) { return connections.flatMap((connection) => connection.models.map((model) => ({ ...model, - connectionName: connection.native_provider || connection.protocol, + connectionName: connection.litellm_provider || connection.protocol, connectionId: connection.id, - provider: connection.native_provider || connection.protocol, + provider: connection.litellm_provider || connection.protocol, })) ); } @@ -156,7 +128,7 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { const [allowlistText, setAllowlistText] = useState(allowlist.join(", ")); const [manualModelId, setManualModelId] = useState(""); - const providerLabel = connection.native_provider || connection.protocol; + const providerLabel = connection.litellm_provider || connection.protocol; const isLocal = connection.protocol === "OLLAMA" || !connection.base_url?.startsWith("https"); function saveAllowlist() { @@ -200,11 +172,7 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { > Test -
@@ -212,8 +180,8 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { {connection.last_status && connection.last_status !== "OK" ? (

- {connection.last_error || "Could not list models."} Chat may still work — add model - IDs manually below. + {connection.last_error || "Could not list models."} Chat may still work — add model IDs + manually below.

) : null} @@ -236,8 +204,8 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {

- Leave empty to discover all models. Recommended for providers with large catalogs - (e.g. OpenRouter). + Leave empty to discover all models. Recommended for providers with large catalogs (e.g. + OpenRouter).

) : null} @@ -314,20 +282,14 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num const createConnection = useAtomValue(createModelConnectionMutationAtom); const updateRoles = useAtomValue(updateModelRolesMutationAtom); - const visiblePresets = useMemo( - () => PRESETS.filter((preset) => !(isCloud() && preset.local)), - [] - ); - const [presetId, setPresetId] = useState(visiblePresets[0]?.id ?? "custom"); - const preset = visiblePresets.find((item) => item.id === presetId) ?? visiblePresets[0]; - const [baseUrl, setBaseUrl] = useState(preset?.baseUrl ?? ""); + const [protocol, setProtocol] = useState("OPENAI_COMPATIBLE"); + const [baseUrl, setBaseUrl] = useState(""); const [apiKey, setApiKey] = useState(""); - // Native providers carry their endpoint inside LiteLLM, so Base URL is hidden - // by default and only revealed for power users who want to override it. - const [showCustomEndpoint, setShowCustomEndpoint] = useState(false); - - const isNative = preset?.protocol === "NATIVE"; - const requiresUrl = !isNative; + const [litellmProvider, setLitellmProvider] = useState(""); + const [showAdvancedProvider, setShowAdvancedProvider] = useState(false); + const selectedProtocol = PROTOCOL_OPTIONS.find((item) => item.value === protocol); + const protocolDefaultProvider = defaultLitellmProvider(protocol); + const isOllama = protocol === "OLLAMA"; const allConnections = [...globalConnections, ...connections]; const enabledModels = flattenModels(allConnections).filter((model) => model.enabled); @@ -335,21 +297,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num const visionModels = enabledModels.filter((model) => capability(model, "vision")); const imageModels = enabledModels.filter((model) => capability(model, "image_gen")); - function onPresetChange(value: string) { - setPresetId(value); - const next = visiblePresets.find((item) => item.id === value); - // Native providers use LiteLLM's built-in endpoint; everything else needs - // (and may prefill) a Base URL. - setBaseUrl(next?.protocol === "NATIVE" ? "" : (next?.baseUrl ?? "")); - setShowCustomEndpoint(false); - } - function handleCreate() { - if (!preset) return; + const explicitProvider = litellmProvider.trim(); createConnection.mutate( { - protocol: preset.protocol, - native_provider: preset.nativeProvider, + protocol, + litellm_provider: explicitProvider ? explicitProvider : null, base_url: baseUrl || null, api_key: apiKey || null, scope: "SEARCH_SPACE", @@ -384,90 +337,89 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
- - setProtocol(value as ConnectionProtocol)} + > - {visiblePresets.map((item) => ( - - - {getProviderIcon(item.nativeProvider || item.protocol, { - className: "size-4", - })} - {item.label} - + {PROTOCOL_OPTIONS.map((item) => ( + + {item.label} ))}
- - {isNative && !showCustomEndpoint ? ( -
-
- Uses provider default -
- -
- ) : ( - <> - setBaseUrl(event.target.value)} - placeholder="https://api.example.com/v1" - list="model-conn-url-suggestions" - /> - - {URL_SUGGESTIONS.map((url) => ( - - - )} + + setBaseUrl(event.target.value)} + placeholder={ + isOllama ? "http://host.docker.internal:11434" : "https://api.example.com/v1" + } + list="model-conn-url-suggestions" + /> + + {URL_SUGGESTIONS.map((url) => ( +
- + setApiKey(event.target.value)} - placeholder={preset?.local ? "Optional for local models" : "API key"} + placeholder={isOllama ? "Optional for Ollama" : "API key"} type="password" />
- {preset?.local ? ( +

- Local URLs are tested from the backend container. Use host.docker.internal instead of - localhost. + {selectedProtocol?.description} Base URL is explicit and editable; no provider presets + are required. Local URLs are tested from the backend container, so use + host.docker.internal instead of localhost.

- ) : isNative ? ( -

- Just paste an API key — {preset?.label} routes through its native endpoint - automatically. After adding, hit Discover (or add model IDs manually). -

- ) : preset?.protocol === "OPENAI_COMPATIBLE" ? ( -

- Enter any OpenAI-compatible endpoint (OpenRouter, Together, Groq, vLLM, LM Studio…). - After adding, hit Discover to list models. -

- ) : null} +
+ + {showAdvancedProvider ? ( +
+ + setLitellmProvider(event.target.value)} + placeholder={protocolDefaultProvider} + /> +

+ Leave empty to use the protocol default. Set this for more accurate LiteLLM + capabilities/costs, for example openrouter, groq, gemini, or azure. +

+
+ ) : null} +
+
{connections.map((connection) => ( diff --git a/surfsense_web/contracts/types/model-connections.types.ts b/surfsense_web/contracts/types/model-connections.types.ts index dcc875251..7a37799c4 100644 --- a/surfsense_web/contracts/types/model-connections.types.ts +++ b/surfsense_web/contracts/types/model-connections.types.ts @@ -1,6 +1,6 @@ import { z } from "zod"; -export const connectionProtocolEnum = z.enum(["OLLAMA", "OPENAI_COMPATIBLE", "NATIVE"]); +export const connectionProtocolEnum = z.enum(["OLLAMA", "OPENAI_COMPATIBLE", "ANTHROPIC"]); export const connectionScopeEnum = z.enum(["GLOBAL", "SEARCH_SPACE", "USER"]); export const modelSourceEnum = z.enum(["DISCOVERED", "MANUAL"]); @@ -32,7 +32,7 @@ export const modelRead = z.object({ export const connectionRead = z.object({ id: z.number(), protocol: z.union([connectionProtocolEnum, z.string()]), - native_provider: z.string().nullable().optional(), + litellm_provider: z.string().nullable().optional(), base_url: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).default({}), scope: z.union([connectionScopeEnum, z.string()]), @@ -49,7 +49,7 @@ export const connectionRead = z.object({ export const connectionCreateRequest = z.object({ protocol: connectionProtocolEnum, - native_provider: z.string().nullable().optional(), + litellm_provider: z.string().nullable().optional(), base_url: z.string().nullable().optional(), api_key: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).default({}), @@ -59,7 +59,7 @@ export const connectionCreateRequest = z.object({ }); export const connectionUpdateRequest = z.object({ - native_provider: z.string().nullable().optional(), + litellm_provider: z.string().nullable().optional(), base_url: z.string().nullable().optional(), api_key: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).optional(), diff --git a/surfsense_web/hooks/use-automation-eligible-models.ts b/surfsense_web/hooks/use-automation-eligible-models.ts index e75235c56..f8b264162 100644 --- a/surfsense_web/hooks/use-automation-eligible-models.ts +++ b/surfsense_web/hooks/use-automation-eligible-models.ts @@ -51,7 +51,7 @@ function buildKind( id: model.id, name: model.display_name || model.model_id, modelName: model.model_id, - provider: connection.native_provider || connection.protocol, + provider: connection.litellm_provider || connection.protocol, isBYOK, }); diff --git a/surfsense_web/lib/apis/model-connections-api.service.ts b/surfsense_web/lib/apis/model-connections-api.service.ts index 7d0f0f59c..92eec8e61 100644 --- a/surfsense_web/lib/apis/model-connections-api.service.ts +++ b/surfsense_web/lib/apis/model-connections-api.service.ts @@ -1,11 +1,13 @@ import { type ConnectionCreateRequest, + type ConnectionRead, type ConnectionUpdateRequest, connectionCreateRequest, connectionListResponse, connectionRead, connectionUpdateRequest, type ModelCreateRequest, + type ModelRead, type ModelRoles, type ModelUpdateRequest, modelCreateRequest, @@ -13,24 +15,25 @@ import { modelRead, modelRoles, modelUpdateRequest, + type VerifyConnectionResponse, verifyConnectionResponse, } from "@/contracts/types/model-connections.types"; import { ValidationError } from "../error"; import { baseApiService } from "./base-api.service"; class ModelConnectionsApiService { - getGlobalConnections = async () => { + getGlobalConnections = async (): Promise => { return baseApiService.get(`/api/v1/global-model-connections`, connectionListResponse); }; - getConnections = async (searchSpaceId: number) => { + getConnections = async (searchSpaceId: number): Promise => { return baseApiService.get( `/api/v1/model-connections?search_space_id=${searchSpaceId}`, connectionListResponse ); }; - createConnection = async (request: ConnectionCreateRequest) => { + createConnection = async (request: ConnectionCreateRequest): Promise => { const parsed = connectionCreateRequest.safeParse(request); if (!parsed.success) { throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); @@ -40,7 +43,10 @@ class ModelConnectionsApiService { }); }; - updateConnection = async (id: number, request: ConnectionUpdateRequest) => { + updateConnection = async ( + id: number, + request: ConnectionUpdateRequest + ): Promise => { const parsed = connectionUpdateRequest.safeParse(request); if (!parsed.success) { throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); @@ -54,15 +60,18 @@ class ModelConnectionsApiService { return baseApiService.delete(`/api/v1/model-connections/${id}`, undefined); }; - verifyConnection = async (id: number) => { + verifyConnection = async (id: number): Promise => { return baseApiService.post(`/api/v1/model-connections/${id}/verify`, verifyConnectionResponse); }; - discoverModels = async (id: number) => { + discoverModels = async (id: number): Promise => { return baseApiService.post(`/api/v1/model-connections/${id}/discover`, modelListResponse); }; - addManualModel = async (connectionId: number, request: ModelCreateRequest) => { + addManualModel = async ( + connectionId: number, + request: ModelCreateRequest + ): Promise => { const parsed = modelCreateRequest.safeParse(request); if (!parsed.success) { throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); @@ -72,7 +81,7 @@ class ModelConnectionsApiService { }); }; - updateModel = async (id: number, request: ModelUpdateRequest) => { + updateModel = async (id: number, request: ModelUpdateRequest): Promise => { const parsed = modelUpdateRequest.safeParse(request); if (!parsed.success) { throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); @@ -82,15 +91,15 @@ class ModelConnectionsApiService { }); }; - testModel = async (id: number) => { + testModel = async (id: number): Promise => { return baseApiService.post(`/api/v1/models/${id}/test`, verifyConnectionResponse); }; - getModelRoles = async (searchSpaceId: number) => { + getModelRoles = async (searchSpaceId: number): Promise => { return baseApiService.get(`/api/v1/search-spaces/${searchSpaceId}/model-roles`, modelRoles); }; - updateModelRoles = async (searchSpaceId: number, roles: ModelRoles) => { + updateModelRoles = async (searchSpaceId: number, roles: ModelRoles): Promise => { return baseApiService.put(`/api/v1/search-spaces/${searchSpaceId}/model-roles`, modelRoles, { body: roles, }); From 5d5d574550db44419cd5a3e112a2329d8c8923ea Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 02:17:22 +0530 Subject: [PATCH 022/212] refactor(model-connections): move backend model connections to provider capabilities --- .../versions/156_add_model_connections.py | 121 ++++--- .../deliverables/tools/generate_image.py | 7 +- .../app/agents/chat/runtime/llm_config.py | 43 ++- surfsense_backend/app/app.py | 2 - surfsense_backend/app/celery_app.py | 2 - surfsense_backend/app/config/__init__.py | 50 +-- surfsense_backend/app/db.py | 17 +- .../app/routes/anonymous_chat_routes.py | 4 +- .../app/routes/image_generation_routes.py | 7 +- .../app/routes/model_connections_routes.py | 98 ++++-- .../app/routes/new_chat_routes.py | 13 +- .../app/routes/new_llm_config_routes.py | 6 +- .../app/routes/search_spaces_routes.py | 6 +- .../app/routes/vision_llm_routes.py | 2 +- surfsense_backend/app/schemas/__init__.py | 1 + .../app/schemas/model_connections.py | 33 +- .../app/services/auto_model_pin_service.py | 30 +- .../app/services/global_model_catalog.py | 28 +- surfsense_backend/app/services/llm_service.py | 14 +- .../app/services/model_capabilities.py | 36 +++ .../app/services/model_connection_service.py | 301 +++++++++++------- .../app/services/model_list_service.py | 23 +- .../app/services/model_resolver.py | 50 +-- .../openrouter_integration_service.py | 90 ++---- .../services/openrouter_model_normalizer.py | 121 +++++++ .../app/services/pricing_registration.py | 16 +- .../app/services/provider_capabilities.py | 12 +- .../app/services/provider_registry.py | 98 ++++++ .../app/services/quality_score.py | 2 +- .../flows/new_chat/llm_capability.py | 2 +- .../chat/streaming/flows/shared/llm_bundle.py | 13 +- 31 files changed, 772 insertions(+), 476 deletions(-) create mode 100644 surfsense_backend/app/services/model_capabilities.py create mode 100644 surfsense_backend/app/services/openrouter_model_normalizer.py create mode 100644 surfsense_backend/app/services/provider_registry.py diff --git a/surfsense_backend/alembic/versions/156_add_model_connections.py b/surfsense_backend/alembic/versions/156_add_model_connections.py index 185debca4..64614db99 100644 --- a/surfsense_backend/alembic/versions/156_add_model_connections.py +++ b/surfsense_backend/alembic/versions/156_add_model_connections.py @@ -17,13 +17,6 @@ branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None -connection_protocol = postgresql.ENUM( - "OLLAMA", - "OPENAI_COMPATIBLE", - "ANTHROPIC", - name="connectionprotocol", - create_type=False, -) connection_scope = postgresql.ENUM( "GLOBAL", "SEARCH_SPACE", @@ -73,36 +66,67 @@ def _add_searchspace_column_if_missing(column_name: str) -> None: op.add_column("searchspaces", sa.Column(column_name, sa.Integer(), nullable=True)) +def _drop_column_if_exists(table_name: str, column_name: str) -> None: + if _column_exists(table_name, column_name): + op.drop_column(table_name, column_name) + + +def _drop_index_if_exists(table_name: str, index_name: str) -> None: + if _index_exists(table_name, index_name): + op.drop_index(index_name, table_name=table_name) + + def upgrade() -> None: bind = op.get_bind() - connection_protocol.create(bind, checkfirst=True) - op.execute("ALTER TYPE connectionprotocol ADD VALUE IF NOT EXISTS 'ANTHROPIC'") connection_scope.create(bind, checkfirst=True) model_source.create(bind, checkfirst=True) if _table_exists("connections"): - if _column_exists("connections", "native_provider") and not _column_exists( - "connections", "litellm_provider" + if _column_exists("connections", "litellm_provider") and not _column_exists( + "connections", "provider" + ): + op.alter_column( + "connections", + "litellm_provider", + new_column_name="provider", + existing_type=sa.String(length=100), + existing_nullable=True, + ) + op.alter_column( + "connections", + "provider", + existing_type=sa.String(length=100), + nullable=False, + ) + elif _column_exists("connections", "native_provider") and not _column_exists( + "connections", "provider" ): op.alter_column( "connections", "native_provider", - new_column_name="litellm_provider", + new_column_name="provider", existing_type=sa.String(length=100), existing_nullable=True, ) - elif not _column_exists("connections", "litellm_provider"): + op.alter_column( + "connections", + "provider", + existing_type=sa.String(length=100), + nullable=False, + ) + elif not _column_exists("connections", "provider"): op.add_column( "connections", - sa.Column("litellm_provider", sa.String(length=100), nullable=True), + sa.Column("provider", sa.String(length=100), nullable=False), ) + _drop_index_if_exists("connections", "ix_connections_protocol") + _drop_column_if_exists("connections", "protocol") else: op.create_table( "connections", sa.Column("id", sa.Integer(), nullable=False), sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("protocol", connection_protocol, nullable=False), - sa.Column("litellm_provider", sa.String(length=100), nullable=True), + sa.Column("provider", sa.String(length=100), nullable=False), sa.Column("base_url", sa.String(length=500), nullable=True), sa.Column("api_key", sa.String(), nullable=True), sa.Column( @@ -131,18 +155,20 @@ def upgrade() -> None: sa.PrimaryKeyConstraint("id"), ) if _index_exists("connections", "ix_connections_native_provider") and not _index_exists( - "connections", "ix_connections_litellm_provider" + "connections", "ix_connections_provider" ): op.execute( "ALTER INDEX ix_connections_native_provider " - "RENAME TO ix_connections_litellm_provider" + "RENAME TO ix_connections_provider" ) - _create_index_if_missing("ix_connections_protocol", "connections", ["protocol"]) - _create_index_if_missing( - "ix_connections_litellm_provider", - "connections", - ["litellm_provider"], - ) + if _index_exists("connections", "ix_connections_litellm_provider") and not _index_exists( + "connections", "ix_connections_provider" + ): + op.execute( + "ALTER INDEX ix_connections_litellm_provider " + "RENAME TO ix_connections_provider" + ) + _create_index_if_missing("ix_connections_provider", "connections", ["provider"]) _create_index_if_missing("ix_connections_scope", "connections", ["scope"]) if not _table_exists("models"): @@ -159,24 +185,11 @@ def upgrade() -> None: server_default="DISCOVERED", nullable=False, ), - sa.Column( - "capabilities", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column( - "capabilities_declared", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column( - "capabilities_verified", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), + sa.Column("supports_chat", sa.Boolean(), nullable=True), + sa.Column("max_input_tokens", sa.Integer(), nullable=True), + sa.Column("supports_image_input", sa.Boolean(), nullable=True), + sa.Column("supports_tools", sa.Boolean(), nullable=True), + sa.Column("supports_image_generation", sa.Boolean(), nullable=True), sa.Column( "capabilities_override", postgresql.JSONB(astext_type=sa.Text()), @@ -198,6 +211,24 @@ def upgrade() -> None: "connection_id", "model_id", name="uq_models_connection_model_id" ), ) + else: + if not _column_exists("models", "supports_chat"): + op.add_column("models", sa.Column("supports_chat", sa.Boolean(), nullable=True)) + if not _column_exists("models", "max_input_tokens"): + op.add_column("models", sa.Column("max_input_tokens", sa.Integer(), nullable=True)) + if not _column_exists("models", "supports_image_input"): + op.add_column( + "models", sa.Column("supports_image_input", sa.Boolean(), nullable=True) + ) + if not _column_exists("models", "supports_tools"): + op.add_column("models", sa.Column("supports_tools", sa.Boolean(), nullable=True)) + if not _column_exists("models", "supports_image_generation"): + op.add_column( + "models", sa.Column("supports_image_generation", sa.Boolean(), nullable=True) + ) + _drop_column_if_exists("models", "capabilities") + _drop_column_if_exists("models", "capabilities_declared") + _drop_column_if_exists("models", "capabilities_verified") _create_index_if_missing("ix_models_connection_id", "models", ["connection_id"]) _create_index_if_missing("ix_models_model_id", "models", ["model_id"]) _create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"]) @@ -206,6 +237,8 @@ def upgrade() -> None: _add_searchspace_column_if_missing("image_gen_model_id") _add_searchspace_column_if_missing("vision_model_id") + op.execute("DROP TYPE IF EXISTS connectionprotocol") + def downgrade() -> None: op.drop_column("searchspaces", "vision_model_id") @@ -218,11 +251,9 @@ def downgrade() -> None: op.drop_table("models") op.drop_index(op.f("ix_connections_scope"), table_name="connections") - op.drop_index(op.f("ix_connections_litellm_provider"), table_name="connections") - op.drop_index(op.f("ix_connections_protocol"), table_name="connections") + op.drop_index(op.f("ix_connections_provider"), table_name="connections") op.drop_table("connections") bind = op.get_bind() model_source.drop(bind, checkfirst=True) connection_scope.drop(bind, checkfirst=True) - connection_protocol.drop(bind, checkfirst=True) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py index fda327750..505831faa 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py @@ -29,6 +29,7 @@ from app.services.auto_model_pin_service import ( auto_model_candidates, choose_auto_model_candidate, ) +from app.services.model_capabilities import has_capability from app.services.model_resolver import to_litellm from app.utils.signed_image_urls import generate_image_token @@ -146,9 +147,7 @@ def create_generate_image_tool( if config_id < 0: global_model = _get_global_model(config_id) - if not global_model or not ( - global_model.get("capabilities") or {} - ).get("image_gen"): + if not global_model or not has_capability(global_model, "image_gen"): err = f"Image generation model {config_id} not found" return _failed({"error": err}, error=err) global_connection = _get_global_connection( @@ -191,7 +190,7 @@ def create_generate_image_tool( ): err = f"Image generation model {config_id} not found" return _failed({"error": err}, error=err) - if not (db_model.capabilities or {}).get("image_gen"): + if not has_capability(db_model, "image_gen"): err = f"Model {config_id} is not image-generation capable" return _failed({"error": err}, error=err) diff --git a/surfsense_backend/app/agents/chat/runtime/llm_config.py b/surfsense_backend/app/agents/chat/runtime/llm_config.py index b9344e001..efc188df8 100644 --- a/surfsense_backend/app/agents/chat/runtime/llm_config.py +++ b/surfsense_backend/app/agents/chat/runtime/llm_config.py @@ -49,16 +49,19 @@ def _sanitize_messages(messages: list[BaseMessage]) -> list[BaseMessage]: reject the blank text. The OpenAI spec says ``content`` should be ``null`` when an assistant message only carries tool calls. """ + sanitized: list[BaseMessage] = [] for msg in messages: - if isinstance(msg.content, list): - msg.content = _sanitize_content(msg.content) + next_msg = msg.model_copy(deep=True) + if isinstance(next_msg.content, list): + next_msg.content = _sanitize_content(next_msg.content) if ( - isinstance(msg, AIMessage) - and (not msg.content or msg.content == "") - and getattr(msg, "tool_calls", None) + isinstance(next_msg, AIMessage) + and (not next_msg.content or next_msg.content == "") + and getattr(next_msg, "tool_calls", None) ): - msg.content = None # type: ignore[assignment] - return messages + next_msg.content = None # type: ignore[assignment] + sanitized.append(next_msg) + return sanitized class SanitizedChatLiteLLM(ChatLiteLLM): @@ -89,6 +92,22 @@ class SanitizedChatLiteLLM(ChatLiteLLM): ): yield chunk + async def _agenerate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + stream: bool | None = None, + **kwargs: Any, + ) -> ChatResult: + return await super()._agenerate( + _sanitize_messages(messages), + stop=stop, + run_manager=run_manager, + stream=stream, + **kwargs, + ) + def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None: """Attach a ``profile`` dict to ChatLiteLLM with model context metadata.""" @@ -210,7 +229,7 @@ class AgentConfig: # BYOK rows have no curated flag; ask LiteLLM (default-allow on # unknown). The streaming safety net still blocks explicit text-only. supports_image_input=derive_supports_image_input( - litellm_provider=provider_value.lower(), + provider=provider_value.lower(), model_name=config.model_name, base_model=base_model, custom_provider=config.custom_provider, @@ -229,7 +248,7 @@ class AgentConfig: system_instructions = yaml_config.get("system_instructions", "") - provider = yaml_config.get("litellm_provider", "") + provider = yaml_config.get("provider") or yaml_config.get("litellm_provider", "") model_name = yaml_config.get("model_name", "") custom_provider = yaml_config.get("custom_provider") litellm_params = yaml_config.get("litellm_params") or {} @@ -245,7 +264,7 @@ class AgentConfig: supports_image_input = bool(yaml_config.get("supports_image_input")) else: supports_image_input = derive_supports_image_input( - litellm_provider=provider, + provider=provider, model_name=model_name, base_model=base_model, custom_provider=custom_provider, @@ -396,8 +415,8 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: if llm_config.get("custom_provider"): model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}" else: - litellm_provider = llm_config.get("litellm_provider", "openai") - model_string = f"{litellm_provider}/{llm_config['model_name']}" + provider = llm_config.get("provider") or llm_config.get("litellm_provider", "openai") + model_string = f"{provider}/{llm_config['model_name']}" litellm_kwargs = { "model": model_string, diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index d3f5dce2a..6dfe6a776 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -33,7 +33,6 @@ from app.config import ( initialize_llm_router, initialize_openrouter_integration, initialize_pricing_registration, - initialize_vision_llm_router, ) from app.db import User, create_db_and_tables, get_async_session from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError @@ -622,7 +621,6 @@ async def lifespan(app: FastAPI): initialize_pricing_registration() initialize_llm_router() initialize_image_gen_router() - initialize_vision_llm_router() # Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays # worker readiness. ``shield`` so Uvicorn cancelling startup diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py index 0e852b801..7ab98203e 100644 --- a/surfsense_backend/app/celery_app.py +++ b/surfsense_backend/app/celery_app.py @@ -115,14 +115,12 @@ def init_worker(**kwargs): initialize_llm_router, initialize_openrouter_integration, initialize_pricing_registration, - initialize_vision_llm_router, ) initialize_openrouter_integration() initialize_pricing_registration() initialize_llm_router() initialize_image_gen_router() - initialize_vision_llm_router() # Celery configuration, sourced from the central Config singleton diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index fd8b29116..eb2a7a18c 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -103,7 +103,7 @@ def load_global_llm_configs(): else None ) cfg["supports_image_input"] = derive_supports_image_input( - litellm_provider=cfg.get("litellm_provider"), + provider=cfg.get("provider") or cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), @@ -122,7 +122,7 @@ def load_global_llm_configs(): # Stamp Auto (Fastest) ranking metadata. YAML configs are always # Tier A — operator-curated, locked first when premium-eligible. # The OpenRouter refresh tick later re-stamps health for any cfg - # whose litellm_provider == "openrouter" via _enrich_health. + # whose provider == "openrouter" via _enrich_health. try: from app.services.quality_score import static_score_yaml @@ -132,7 +132,7 @@ def load_global_llm_configs(): cfg["quality_score_static"] = static_q cfg["quality_score"] = static_q cfg["quality_score_health"] = None - # YAML cfgs whose litellm_provider is openrouter are also subject + # YAML cfgs whose provider is openrouter are also subject # to health gating against their own /endpoints data — a # hand-picked dead OR model is still dead. _enrich_health # re-stamps health_gated for them on the next refresh tick. @@ -362,8 +362,8 @@ def initialize_openrouter_integration(): else: print("Info: OpenRouter integration enabled but no models fetched") - # Image generation + vision LLM emissions are opt-in (issue L). - # Both reuse the catalogue already cached by ``service.initialize`` + # Image generation emissions reuse the catalogue already cached by + # ``service.initialize`` # so we don't make additional network calls here. if settings.get("image_generation_enabled"): try: @@ -377,18 +377,6 @@ def initialize_openrouter_integration(): except Exception as e: print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}") - if settings.get("vision_enabled"): - try: - vision_configs = service.get_vision_llm_configs() - if vision_configs: - config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs) - print( - f"Info: OpenRouter integration added {len(vision_configs)} " - f"vision LLM models" - ) - except Exception as e: - print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}") - refresh_global_model_catalog() except Exception as e: print(f"Warning: Failed to initialize OpenRouter integration: {e}") @@ -399,7 +387,6 @@ def materialize_global_configs(): return materialize_global_model_catalog( chat_configs=getattr(config, "GLOBAL_LLM_CONFIGS", []), - vision_configs=getattr(config, "GLOBAL_VISION_LLM_CONFIGS", []), image_configs=getattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", []), ) @@ -493,29 +480,9 @@ def initialize_image_gen_router(): def initialize_vision_llm_router(): - vision_configs = load_global_vision_llm_configs() - # Reuse the router settings already parsed at Config construction. The - # *configs* list is intentionally re-read from YAML (it must exclude the - # OpenRouter-injected dynamic models held in config.GLOBAL_VISION_LLM_CONFIGS). - router_settings = config.VISION_LLM_ROUTER_SETTINGS - - if not vision_configs: - print( - "Info: No global vision LLM configs found, " - "Vision LLM Auto mode will not be available" - ) - return - - try: - from app.services.vision_llm_router_service import VisionLLMRouterService - - VisionLLMRouterService.initialize(vision_configs, router_settings) - print( - f"Info: Vision LLM Router initialized with {len(vision_configs)} models " - f"(strategy: {router_settings.get('routing_strategy', 'usage-based-routing')})" - ) - except Exception as e: - print(f"Warning: Failed to initialize Vision LLM Router: {e}") + # Retired: vision Auto now uses shared capability-filtered model selection + # over GLOBAL/BYOK chat models with supports_image_input=true. + return class Config: @@ -874,7 +841,6 @@ class Config: GLOBAL_CONNECTIONS, GLOBAL_MODELS = _materialize_global_model_catalog( chat_configs=GLOBAL_LLM_CONFIGS, - vision_configs=GLOBAL_VISION_LLM_CONFIGS, image_configs=GLOBAL_IMAGE_GEN_CONFIGS, ) del _materialize_global_model_catalog diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 4c628b05a..9053d7055 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -280,12 +280,6 @@ class VisionProvider(StrEnum): CUSTOM = "CUSTOM" -class ConnectionProtocol(StrEnum): - OLLAMA = "OLLAMA" - OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" - ANTHROPIC = "ANTHROPIC" - - class ConnectionScope(StrEnum): GLOBAL = "GLOBAL" SEARCH_SPACE = "SEARCH_SPACE" @@ -1662,8 +1656,7 @@ class Report(BaseModel, TimestampMixin): class Connection(BaseModel, TimestampMixin): __tablename__ = "connections" - protocol = Column(SQLAlchemyEnum(ConnectionProtocol), nullable=False, index=True) - litellm_provider = Column(String(100), nullable=True, index=True) + provider = Column(String(100), nullable=False, index=True) base_url = Column(String(500), nullable=True) api_key = Column(String, nullable=True) extra = Column(JSONB, nullable=False, default=dict, server_default="{}") @@ -1715,9 +1708,11 @@ class Model(BaseModel, TimestampMixin): default=ModelSource.DISCOVERED, server_default=ModelSource.DISCOVERED.value, ) - capabilities = Column(JSONB, nullable=False, default=dict, server_default="{}") - capabilities_declared = Column(JSONB, nullable=False, default=dict, server_default="{}") - capabilities_verified = Column(JSONB, nullable=False, default=dict, server_default="{}") + supports_chat = Column(Boolean, nullable=True) + max_input_tokens = Column(Integer, nullable=True) + supports_image_input = Column(Boolean, nullable=True) + supports_tools = Column(Boolean, nullable=True) + supports_image_generation = Column(Boolean, nullable=True) capabilities_override = Column(JSONB, nullable=False, default=dict, server_default="{}") embedding_dimension = Column(Integer, nullable=True) enabled = Column(Boolean, nullable=False, default=True, server_default="true") diff --git a/surfsense_backend/app/routes/anonymous_chat_routes.py b/surfsense_backend/app/routes/anonymous_chat_routes.py index aba1a3a12..84420e738 100644 --- a/surfsense_backend/app/routes/anonymous_chat_routes.py +++ b/surfsense_backend/app/routes/anonymous_chat_routes.py @@ -132,7 +132,7 @@ async def list_anonymous_models(): id=cfg.get("id", 0), name=cfg.get("name", ""), description=cfg.get("description"), - provider=cfg.get("litellm_provider", ""), + provider=cfg.get("provider") or cfg.get("litellm_provider", ""), model_name=cfg.get("model_name", ""), billing_tier=cfg.get("billing_tier", "free"), is_premium=cfg.get("billing_tier", "free") == "premium", @@ -161,7 +161,7 @@ async def get_anonymous_model(slug: str): id=cfg.get("id", 0), name=cfg.get("name", ""), description=cfg.get("description"), - provider=cfg.get("litellm_provider", ""), + provider=cfg.get("provider") or cfg.get("litellm_provider", ""), model_name=cfg.get("model_name", ""), billing_tier=cfg.get("billing_tier", "free"), is_premium=cfg.get("billing_tier", "free") == "premium", diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 5be1cedf1..e8f14bd71 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -52,6 +52,7 @@ from app.services.auto_model_pin_service import ( choose_auto_model_candidate, ) from app.services.model_resolver import to_litellm +from app.services.model_capabilities import has_capability from app.users import current_active_user from app.utils.rbac import check_permission from app.utils.signed_image_urls import verify_image_token @@ -166,7 +167,7 @@ async def _execute_image_generation( if config_id < 0: global_model = _get_global_model(config_id) - if not global_model or not (global_model.get("capabilities") or {}).get("image_gen"): + if not global_model or not has_capability(global_model, "image_gen"): raise ValueError(f"Global image generation model {config_id} not found") global_connection = _get_global_connection(global_model["connection_id"]) if not global_connection: @@ -200,7 +201,7 @@ async def _execute_image_generation( raise ValueError(f"Image generation model {config_id} not found") if conn.user_id is not None and conn.user_id != search_space.user_id: raise ValueError(f"Image generation model {config_id} not found") - if not (db_model.capabilities or {}).get("image_gen"): + if not has_capability(db_model, "image_gen"): raise ValueError(f"Model {config_id} is not image-generation capable") model_string, resolved_kwargs = to_litellm( @@ -272,7 +273,7 @@ async def get_global_image_gen_configs( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("litellm_provider"), + "provider": cfg.get("provider") or cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index cae951c3a..ecb86711e 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -8,7 +8,6 @@ from sqlalchemy.orm import selectinload from app.config import config from app.db import ( Connection, - ConnectionProtocol, ConnectionScope, Model, ModelSource, @@ -22,6 +21,7 @@ from app.schemas import ( ConnectionRead, ConnectionUpdate, ModelCreate, + ModelProviderRead, ModelRead, ModelRolesRead, ModelRolesUpdate, @@ -34,6 +34,8 @@ from app.services.model_connection_service import ( persist_verification, test_model, ) +from app.services.model_capabilities import has_capability +from app.services.provider_registry import REGISTRY from app.users import current_active_user from app.utils.rbac import check_permission @@ -41,16 +43,6 @@ router = APIRouter() logger = logging.getLogger(__name__) -def _default_litellm_provider(protocol: ConnectionProtocol | str) -> str: - protocol_value = getattr(protocol, "value", str(protocol)) - defaults = { - ConnectionProtocol.OLLAMA.value: "ollama_chat", - ConnectionProtocol.ANTHROPIC.value: "anthropic", - ConnectionProtocol.OPENAI_COMPATIBLE.value: "openai", - } - return defaults.get(protocol_value, "openai") - - def _model_read(model: Model | dict) -> ModelRead: return ModelRead.model_validate(model) @@ -68,8 +60,7 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None return ConnectionRead( id=conn.id, - protocol=conn.protocol, - litellm_provider=conn.litellm_provider, + provider=conn.provider, base_url=conn.base_url, extra=conn.extra or {}, scope=conn.scope, @@ -85,6 +76,60 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None ) +def _apply_model_facts(model: Model, facts: dict) -> None: + model.supports_chat = facts.get("supports_chat") + model.max_input_tokens = facts.get("max_input_tokens") + model.supports_image_input = facts.get("supports_image_input") + model.supports_tools = facts.get("supports_tools") + model.supports_image_generation = facts.get("supports_image_generation") + + +def _default_model_for(models: list[Model], capability: str) -> int | None: + for model in models: + if model.enabled and has_capability(model, capability): + return model.id + return None + + +async def _default_unset_roles( + session: AsyncSession, + conn: Connection, + models: list[Model], +) -> None: + if conn.scope != ConnectionScope.SEARCH_SPACE or conn.search_space_id is None: + return + search_space = await _get_search_space(session, conn.search_space_id) + if search_space.chat_model_id is None: + search_space.chat_model_id = _default_model_for(models, "chat") + if search_space.vision_model_id is None: + vision_default = None + if search_space.chat_model_id: + chat_model = next((m for m in models if m.id == search_space.chat_model_id), None) + if chat_model and has_capability(chat_model, "vision"): + vision_default = chat_model.id + search_space.vision_model_id = vision_default or _default_model_for(models, "vision") + if search_space.image_gen_model_id is None: + search_space.image_gen_model_id = _default_model_for(models, "image_gen") + + +@router.get("/model-providers", response_model=list[ModelProviderRead]) +async def list_model_providers(user: User = Depends(current_active_user)): + del user + local_only = {"ollama_chat", "lm_studio"} + return [ + ModelProviderRead( + provider=provider, + transport=spec.transport.value, + discovery=spec.discovery, + default_base_url=spec.default_base_url, + base_url_required=spec.base_url_required, + auth_style=spec.auth_style, + local_only=provider in local_only, + ) + for provider, spec in sorted(REGISTRY.items()) + ] + + async def _get_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace: result = await session.execute(select(SearchSpace).where(SearchSpace.id == search_space_id)) search_space = result.scalars().first() @@ -180,8 +225,6 @@ async def create_connection( "You don't have permission to create model connections in this search space", ) payload = data.model_dump(exclude={"search_space_id"}) - if not payload.get("litellm_provider"): - payload["litellm_provider"] = _default_litellm_provider(data.protocol) conn = Connection( **payload, @@ -254,24 +297,21 @@ async def discover_connection_models( model_id=item["model_id"], display_name=item.get("display_name"), source=item["source"], - capabilities=item["capabilities"], - capabilities_declared=item["capabilities"], - capabilities_verified={}, capabilities_override={}, enabled=False, catalog=item.get("metadata") or {}, ) + _apply_model_facts(db_model, item) session.add(db_model) else: db_model.display_name = item.get("display_name") or db_model.display_name - db_model.capabilities_declared = item["capabilities"] - db_model.capabilities = { - **item["capabilities"], - **(db_model.capabilities_override or {}), - } + _apply_model_facts(db_model, item) db_model.catalog = item.get("metadata") or db_model.catalog await session.commit() conn = await _load_connection(session, connection_id) + await _default_unset_roles(session, conn, list(conn.models)) + await session.commit() + conn = await _load_connection(session, connection_id) return [_model_read(model) for model in conn.models] @@ -297,16 +337,17 @@ async def add_manual_model( model_id=model_id, display_name=data.display_name or None, source=ModelSource.MANUAL, - capabilities=capabilities, - capabilities_declared=capabilities, - capabilities_verified={}, capabilities_override={}, enabled=True, catalog={}, ) + _apply_model_facts(model, capabilities) session.add(model) await session.commit() await session.refresh(model) + conn = await _load_connection(session, connection_id) + await _default_unset_roles(session, conn, list(conn.models)) + await session.commit() return _model_read(model) @@ -327,11 +368,6 @@ async def update_model( update = data.model_dump(exclude_unset=True) for key, value in update.items(): setattr(model, key, value) - if "capabilities_override" in update: - model.capabilities = { - **(model.capabilities_declared or {}), - **(model.capabilities_override or {}), - } await session.commit() await session.refresh(model) return _model_read(model) diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 0e4e557be..b5bc2571e 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -1741,12 +1741,11 @@ async def handle_new_chat( if not search_space: raise HTTPException(status_code=404, detail="Search space not found") - # Use agent_llm_id from search space for chat operations - # Positive IDs load from NewLLMConfig database table - # Negative IDs load from YAML global configs - # Falls back to -1 (first global config) if not configured + # Use the converged model-connections role for chat operations. + # Positive IDs load Model + Connection rows; negative IDs load + # virtual GLOBAL models; 0 means Auto. llm_config_id = ( - search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 + search_space.chat_model_id if search_space.chat_model_id is not None else 0 ) # Release the read-transaction so we don't hold ACCESS SHARE locks @@ -2228,7 +2227,7 @@ async def regenerate_response( raise HTTPException(status_code=404, detail="Search space not found") llm_config_id = ( - search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 + search_space.chat_model_id if search_space.chat_model_id is not None else 0 ) # Release the read-transaction so we don't hold ACCESS SHARE locks @@ -2393,7 +2392,7 @@ async def resume_chat( raise HTTPException(status_code=404, detail="Search space not found") llm_config_id = ( - search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 + search_space.chat_model_id if search_space.chat_model_id is not None else 0 ) decisions = [d.model_dump() for d in request.decisions] diff --git a/surfsense_backend/app/routes/new_llm_config_routes.py b/surfsense_backend/app/routes/new_llm_config_routes.py index 531fa8730..adba5b5ae 100644 --- a/surfsense_backend/app/routes/new_llm_config_routes.py +++ b/surfsense_backend/app/routes/new_llm_config_routes.py @@ -57,7 +57,7 @@ def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead: litellm_params.get("base_model") if isinstance(litellm_params, dict) else None ) supports_image_input = derive_supports_image_input( - litellm_provider=provider_value.lower(), + provider=provider_value.lower(), model_name=config.model_name, base_model=base_model, custom_provider=config.custom_provider, @@ -147,7 +147,7 @@ async def get_global_new_llm_configs( else None ) supports_image_input = derive_supports_image_input( - litellm_provider=cfg.get("litellm_provider"), + provider=cfg.get("provider") or cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=cfg_base_model, custom_provider=cfg.get("custom_provider"), @@ -157,7 +157,7 @@ async def get_global_new_llm_configs( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("litellm_provider"), + "provider": cfg.get("provider") or cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 2cda04221..7c5fbf28b 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -419,7 +419,7 @@ async def _get_llm_config_by_id( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("litellm_provider"), + "provider": cfg.get("provider") or cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base"), @@ -490,7 +490,7 @@ async def _get_image_gen_config_by_id( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("litellm_provider"), + "provider": cfg.get("provider") or cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, @@ -550,7 +550,7 @@ async def _get_vision_llm_config_by_id( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("litellm_provider"), + "provider": cfg.get("provider") or cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py index df218daac..b93d25b9c 100644 --- a/surfsense_backend/app/routes/vision_llm_routes.py +++ b/surfsense_backend/app/routes/vision_llm_routes.py @@ -96,7 +96,7 @@ async def get_global_vision_llm_configs( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("litellm_provider"), + "provider": cfg.get("provider") or cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 2a06eca5c..dde9fcef1 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -49,6 +49,7 @@ from .model_connections import ( ConnectionRead, ConnectionUpdate, ModelCreate, + ModelProviderRead, ModelRead, ModelRolesRead, ModelRolesUpdate, diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py index 306dd63c8..c081a193d 100644 --- a/surfsense_backend/app/schemas/model_connections.py +++ b/surfsense_backend/app/schemas/model_connections.py @@ -4,7 +4,7 @@ from typing import Any from pydantic import BaseModel, ConfigDict, Field -from app.db import ConnectionProtocol, ConnectionScope, ModelSource +from app.db import ConnectionScope, ModelSource class ModelRead(BaseModel): @@ -13,9 +13,11 @@ class ModelRead(BaseModel): model_id: str display_name: str | None = None source: ModelSource | str - capabilities: dict[str, Any] - capabilities_declared: dict[str, Any] = Field(default_factory=dict) - capabilities_verified: dict[str, Any] = Field(default_factory=dict) + supports_chat: bool | None = None + max_input_tokens: int | None = None + supports_image_input: bool | None = None + supports_tools: bool | None = None + supports_image_generation: bool | None = None capabilities_override: dict[str, Any] = Field(default_factory=dict) embedding_dimension: int | None = None enabled: bool @@ -28,8 +30,7 @@ class ModelRead(BaseModel): class ConnectionRead(BaseModel): id: int - protocol: ConnectionProtocol | str - litellm_provider: str | None = None + provider: str base_url: str | None = None extra: dict[str, Any] = Field(default_factory=dict) scope: ConnectionScope | str @@ -47,8 +48,7 @@ class ConnectionRead(BaseModel): class ConnectionCreate(BaseModel): - protocol: ConnectionProtocol - litellm_provider: str | None = Field(None, max_length=100) + provider: str = Field(..., max_length=100) base_url: str | None = Field(None, max_length=500) api_key: str | None = None extra: dict[str, Any] = Field(default_factory=dict) @@ -58,7 +58,7 @@ class ConnectionCreate(BaseModel): class ConnectionUpdate(BaseModel): - litellm_provider: str | None = Field(None, max_length=100) + provider: str | None = Field(None, max_length=100) base_url: str | None = Field(None, max_length=500) api_key: str | None = None extra: dict[str, Any] | None = None @@ -79,9 +79,24 @@ class ModelCreate(BaseModel): class ModelUpdate(BaseModel): display_name: str | None = Field(None, max_length=255) enabled: bool | None = None + supports_chat: bool | None = None + max_input_tokens: int | None = None + supports_image_input: bool | None = None + supports_tools: bool | None = None + supports_image_generation: bool | None = None capabilities_override: dict[str, Any] | None = None +class ModelProviderRead(BaseModel): + provider: str + transport: str + discovery: str + default_base_url: str | None = None + base_url_required: bool + auth_style: str + local_only: bool = False + + class VerifyConnectionResponse(BaseModel): status: str ok: bool diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index ee8c4b8dc..652029e76 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -27,6 +27,7 @@ from sqlalchemy.orm import selectinload from app.config import config from app.db import Connection, Model, NewChatThread +from app.services.model_capabilities import has_capability from app.services.quality_score import _QUALITY_TOP_K from app.services.token_quota_service import TokenQuotaService @@ -62,18 +63,13 @@ def _is_usable_global_config(cfg: dict) -> bool: return bool( cfg.get("id") is not None and cfg.get("model_name") - and cfg.get("litellm_provider") + and (cfg.get("provider") or cfg.get("litellm_provider")) and cfg.get("api_key") ) def _has_capability(model: dict | Model, capability: str) -> bool: - caps = ( - model.get("capabilities", {}) - if isinstance(model, dict) - else model.capabilities or {} - ) - return bool(caps.get(capability)) + return has_capability(model, capability) def _prune_runtime_cooldowns(now_ts: float | None = None) -> None: @@ -196,7 +192,7 @@ def _cfg_supports_image_input(cfg: dict) -> bool: else None ) return derive_supports_image_input( - litellm_provider=cfg.get("litellm_provider"), + provider=cfg.get("provider") or cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), @@ -253,9 +249,13 @@ def _global_candidates( "model_id": model.get("model_id"), "source": "global", "connection": connection, - "capabilities": model.get("capabilities") or {}, + "supports_chat": model.get("supports_chat"), + "supports_image_input": model.get("supports_image_input"), + "supports_tools": model.get("supports_tools"), + "supports_image_generation": model.get("supports_image_generation"), + "capabilities_override": model.get("capabilities_override") or {}, "billing_tier": model.get("billing_tier", "free"), - "litellm_provider": connection.get("litellm_provider"), + "provider": connection.get("provider"), "model_name": model.get("model_id"), "auto_pin_tier": catalog.get("auto_pin_tier") or cfg.get("auto_pin_tier") @@ -310,9 +310,13 @@ async def _db_candidates( "model_id": model.model_id, "source": "db", "connection": conn, - "capabilities": model.capabilities or {}, + "supports_chat": model.supports_chat, + "supports_image_input": model.supports_image_input, + "supports_tools": model.supports_tools, + "supports_image_generation": model.supports_image_generation, + "capabilities_override": model.capabilities_override or {}, "billing_tier": "byok", - "litellm_provider": conn.litellm_provider, + "provider": conn.provider, "model_name": model.model_id, "auto_pin_tier": catalog.get("auto_pin_tier") or "BYOK", "quality_score": catalog.get("quality_score") or 75, @@ -357,7 +361,7 @@ def _is_preferred_premium_auto_config(cfg: dict) -> bool: return ( cfg.get("source") == "global" and _tier_of(cfg) == "premium" - and str(cfg.get("litellm_provider", "")).lower() == "azure" + and str(cfg.get("provider", "")).lower() == "azure" and str(cfg.get("model_name", "")).lower() == "gpt-5.4" ) diff --git a/surfsense_backend/app/services/global_model_catalog.py b/surfsense_backend/app/services/global_model_catalog.py index e40b3a942..ca1249497 100644 --- a/surfsense_backend/app/services/global_model_catalog.py +++ b/surfsense_backend/app/services/global_model_catalog.py @@ -18,8 +18,7 @@ def _connection_key(conn: dict[str, Any]) -> tuple[Any, ...]: # Deliberately includes api_key because two operator-owned credentials for # the same provider/base can have different quota/rate limits upstream. return ( - conn.get("protocol"), - conn.get("litellm_provider"), + conn.get("provider"), conn.get("base_url"), conn.get("api_key"), _freeze(conn.get("extra") or {}), @@ -34,16 +33,6 @@ def _freeze(value: Any) -> Any: return value -def _capabilities_for(role: str, config: dict[str, Any]) -> dict[str, bool]: - return { - "chat": role == "chat", - "vision": role == "vision" or bool(config.get("supports_image_input")), - "image_gen": role == "image_gen", - "embedding": False, - "tools": bool(config.get("supports_tools", False)), - } - - def _catalog_metadata(config: dict[str, Any]) -> dict[str, Any]: return { "billing_tier": config.get("billing_tier", "free"), @@ -72,7 +61,6 @@ def _catalog_metadata(config: dict[str, Any]) -> dict[str, Any]: def materialize_global_model_catalog( *, chat_configs: list[dict[str, Any]], - vision_configs: list[dict[str, Any]], image_configs: list[dict[str, Any]], ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: connections: list[dict[str, Any]] = [] @@ -109,9 +97,13 @@ def materialize_global_model_catalog( "model_id": config["model_name"], "display_name": config.get("name") or config["model_name"], "source": "MANUAL", - "capabilities": _capabilities_for(role, config), - "capabilities_declared": _capabilities_for(role, config), - "capabilities_verified": _capabilities_for(role, config), + "supports_chat": role == "chat", + "max_input_tokens": config.get("max_input_tokens"), + "supports_image_input": ( + role == "chat" and bool(config.get("supports_image_input")) + ), + "supports_tools": bool(config.get("supports_tools", False)), + "supports_image_generation": role == "image_gen", "capabilities_override": {}, "embedding_dimension": None, "enabled": True, @@ -125,10 +117,6 @@ def materialize_global_model_catalog( if cfg.get("is_auto_mode"): continue add_config(cfg, "chat") - for cfg in vision_configs: - if cfg.get("is_auto_mode"): - continue - add_config(cfg, "vision") for cfg in image_configs: if cfg.get("is_auto_mode"): continue diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 86a9c8556..eadb4dbf8 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -15,6 +15,7 @@ from app.services.auto_model_pin_service import ( choose_auto_model_candidate, ) from app.services.llm_router_service import AUTO_MODE_ID, ChatLiteLLMRouter, is_auto_mode +from app.services.model_capabilities import has_capability from app.services.model_resolver import native_connection_from_config, to_litellm from app.services.token_tracking_service import token_tracker @@ -76,7 +77,7 @@ def _legacy_config_connection( api_version: str | None = None, ) -> tuple[str, dict]: cfg = { - "litellm_provider": provider.lower(), + "provider": provider.lower(), "model_name": model_name, "api_key": api_key, "api_base": api_base, @@ -136,12 +137,7 @@ def get_global_connection(connection_id: int) -> dict | None: def _has_capability(model: dict | Model, capability: str) -> bool: - caps = ( - model.get("capabilities", {}) - if isinstance(model, dict) - else model.capabilities or {} - ) - return bool(caps.get(capability)) + return has_capability(model, capability) def _chat_litellm_from_resolved( @@ -420,8 +416,6 @@ async def get_vision_llm( unwrapped — they don't consume premium credit (issue M). """ from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM - from app.services.vision_llm_router_service import is_vision_auto_mode - try: result = await session.execute( select(SearchSpace).where(SearchSpace.id == search_space_id) @@ -468,7 +462,7 @@ async def get_vision_llm( logger.error(f"No vision LLM configured for search space {search_space_id}") return None - if is_vision_auto_mode(config_id): + if config_id == AUTO_MODE_ID: candidates = await auto_model_candidates( session, search_space_id=search_space_id, diff --git a/surfsense_backend/app/services/model_capabilities.py b/surfsense_backend/app/services/model_capabilities.py new file mode 100644 index 000000000..fb7681f35 --- /dev/null +++ b/surfsense_backend/app/services/model_capabilities.py @@ -0,0 +1,36 @@ +"""Override-aware model capability lookup.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +CAPABILITY_FIELDS = { + "chat": "supports_chat", + "vision": "supports_image_input", + "image_gen": "supports_image_generation", + "tools": "supports_tools", +} + + +def _get_value(model: Any, key: str) -> Any: + if isinstance(model, Mapping): + return model.get(key) + return getattr(model, key, None) + + +def has_capability(model: Any, capability: str) -> bool: + field = CAPABILITY_FIELDS.get(capability) + if field is None: + return False + + override = _get_value(model, "capabilities_override") or {} + if isinstance(override, Mapping) and field in override: + return bool(override[field]) + if isinstance(override, Mapping) and capability in override: + return bool(override[capability]) + + return bool(_get_value(model, field)) + + +__all__ = ["CAPABILITY_FIELDS", "has_capability"] diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index 5e5b231f9..428af736e 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -11,8 +11,10 @@ from typing import Any import httpx import litellm -from app.db import Connection, ConnectionProtocol, Model, ModelSource +from app.db import Connection, Model, ModelSource from app.services.model_resolver import ensure_v1, to_litellm +from app.services.openrouter_model_normalizer import normalize_openrouter_models +from app.services.provider_registry import Transport, spec_for logger = logging.getLogger(__name__) @@ -41,6 +43,16 @@ def _anthropic_headers(conn: Connection) -> dict[str, str]: return headers +def _base_url_or_default(conn: Connection) -> str | None: + if conn.base_url: + return conn.base_url.rstrip("/") + if conn.provider == "openai": + return "https://api.openai.com/v1" + if conn.provider == "anthropic": + return "https://api.anthropic.com/v1" + return spec_for(conn.provider).default_base_url + + def _docker_hint(url: str | None, exc_or_status: Any) -> str: raw = str(exc_or_status) if not url: @@ -61,32 +73,30 @@ def _docker_hint(url: str | None, exc_or_status: Any) -> str: async def verify_connection(conn: Connection) -> VerifyResult: - if not conn.base_url: + spec = spec_for(conn.provider) + base_url = _base_url_or_default(conn) + if spec.base_url_required and not base_url: return VerifyResult("UNREACHABLE", False, "Base URL is required.") - if conn.protocol == ConnectionProtocol.OLLAMA: - url = f"{conn.base_url.rstrip('/')}/api/version" - elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: - url = f"{ensure_v1(conn.base_url)}/models" - elif conn.protocol == ConnectionProtocol.ANTHROPIC: - url = f"{conn.base_url.rstrip('/')}/models" + if spec.transport == Transport.OLLAMA and base_url: + url = f"{base_url.rstrip('/')}/api/version" + elif spec.discovery in {"openai_models", "openrouter"} and base_url: + url = f"{ensure_v1(base_url)}/models" + elif spec.discovery == "anthropic_models" and base_url: + url = f"{base_url.rstrip('/')}/models" else: - return VerifyResult("UNREACHABLE", False, "Unsupported connection protocol.") + return VerifyResult("OK", True, "Connection uses provider-native authentication.") try: async with httpx.AsyncClient(timeout=VERIFY_TIMEOUT_SECONDS) as client: - headers = ( - _anthropic_headers(conn) - if conn.protocol == ConnectionProtocol.ANTHROPIC - else _auth_headers(conn) - ) + headers = _anthropic_headers(conn) if spec.auth_style == "x-api-key" else _auth_headers(conn) response = await client.get(url, headers=headers) if response.status_code in (401, 403): return VerifyResult("AUTH_FAILED", False, "Authentication failed.") if response.status_code == 404: - if conn.protocol == ConnectionProtocol.OLLAMA and url.endswith("/v1/models"): + if spec.transport == Transport.OLLAMA and url.endswith("/v1/models"): message = "Ollama native API should not use /v1." - elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: + elif spec.transport == Transport.OPENAI_COMPATIBLE: message = "OpenAI-compatible servers should expose /v1/models." else: message = "Endpoint returned 404." @@ -94,11 +104,11 @@ async def verify_connection(conn: Connection) -> VerifyResult: response.raise_for_status() return VerifyResult("OK", True, "Connection verified.") except httpx.ConnectError as exc: - return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc)) + return VerifyResult("UNREACHABLE", False, _docker_hint(base_url, exc)) except httpx.TimeoutException as exc: return VerifyResult("UNREACHABLE", False, f"Connection timed out: {exc}") except httpx.HTTPError as exc: - return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc)) + return VerifyResult("UNREACHABLE", False, _docker_hint(base_url, exc)) async def persist_verification(conn: Connection) -> VerifyResult: @@ -109,123 +119,193 @@ async def persist_verification(conn: Connection) -> VerifyResult: return result -def _litellm_capabilities(model_string: str, model_id: str) -> dict[str, bool]: - capabilities = { - "chat": True, - "vision": False, - "tools": False, - "image_gen": False, - "embedding": False, - } - with contextlib.suppress(Exception): - capabilities["vision"] = bool(litellm.supports_vision(model=model_string)) - with contextlib.suppress(Exception): - capabilities["tools"] = bool(litellm.supports_function_calling(model=model_string)) - try: - info = litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {} - mode = str(info.get("mode") or "") - capabilities["embedding"] = mode == "embedding" - capabilities["image_gen"] = mode in {"image_generation", "image_generation_model"} - except Exception: - pass - return capabilities - - def _allowlist(conn: Connection) -> set[str]: - """Per-connection model-id allowlist stored in ``extra.model_ids``. - - Empty/absent means "no restriction" (discover everything), mirroring - OpenWebUI's behaviour. A non-empty list restricts discovery to those ids — - essential for providers like OpenRouter that expose hundreds of models. - """ raw = (conn.extra or {}).get("model_ids") or [] return {str(item).strip() for item in raw if str(item).strip()} -async def _discover_openai_shaped_models(conn: Connection, base_url: str | None) -> list[dict[str, Any]]: - if not base_url: +def _litellm_info(model_string: str, model_id: str) -> dict[str, Any]: + with contextlib.suppress(Exception): + info = litellm.get_model_info(model=model_string) + if isinstance(info, dict): + return info + return litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {} + + +def _classify_from_litellm(model_string: str, model_id: str) -> dict[str, Any]: + info = _litellm_info(model_string, model_id) + mode = info.get("mode") + supports_image_input = False + supports_tools = False + with contextlib.suppress(Exception): + supports_image_input = bool(litellm.supports_vision(model=model_string)) + with contextlib.suppress(Exception): + supports_tools = bool(litellm.supports_function_calling(model=model_string)) + return { + "supports_chat": mode in (None, "chat", "completion", "responses"), + "max_input_tokens": info.get("max_input_tokens") or info.get("max_tokens"), + "supports_image_input": supports_image_input, + "supports_tools": supports_tools, + "supports_image_generation": mode in {"image_generation", "image_generation_model"}, + } + + +def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, Any]: + metadata = metadata or {} + spec = spec_for(conn.provider) + model_string, _ = to_litellm(conn, model_id) + facts = _classify_from_litellm(model_string, model_id) + if spec.transport == Transport.OLLAMA: + caps = set(metadata.get("capabilities") or []) + details = metadata.get("details") or {} + facts.update( + { + "supports_chat": "embedding" not in caps, + "supports_image_input": "vision" in caps or facts["supports_image_input"], + "supports_tools": "tools" in caps or facts["supports_tools"], + "supports_image_generation": False, + "max_input_tokens": metadata.get("context_length") + or metadata.get("num_ctx") + or details.get("context_length") + or facts["max_input_tokens"], + } + ) + return facts + + +async def _discover_openai_shaped_models( + conn: Connection, base_url: str | None +) -> list[dict[str, Any]]: + resolved_base_url = base_url or _base_url_or_default(conn) + if not resolved_base_url: return [] - url = f"{ensure_v1(base_url)}/models" + url = f"{ensure_v1(resolved_base_url)}/models" async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: response = await client.get(url, headers=_auth_headers(conn)) response.raise_for_status() - return [ - { - "model_id": item.get("id"), - "display_name": item.get("name") or item.get("id"), - "source": ModelSource.DISCOVERED, - "capabilities": derive_capabilities(conn, item.get("id"), item), - "metadata": item, - } - for item in response.json().get("data", []) - if item.get("id") - ] + + results: list[dict[str, Any]] = [] + for item in response.json().get("data", []): + model_id = item.get("id") + if not model_id: + continue + results.append( + { + "model_id": model_id, + "display_name": item.get("name") or model_id, + "source": ModelSource.DISCOVERED, + **derive_capabilities(conn, model_id, item), + "metadata": item, + } + ) + return results async def _discover_anthropic_models(conn: Connection) -> list[dict[str, Any]]: - if not conn.base_url: + base_url = _base_url_or_default(conn) + if not base_url: return [] - url = f"{conn.base_url.rstrip('/')}/models" + url = f"{base_url.rstrip('/')}/models" async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: response = await client.get(url, headers=_anthropic_headers(conn)) response.raise_for_status() - models = response.json().get("data", []) - return [ - { - "model_id": item.get("id"), - "display_name": item.get("display_name") or item.get("id"), - "source": ModelSource.DISCOVERED, - "capabilities": derive_capabilities(conn, item.get("id"), item), - "metadata": item, - } - for item in models - if item.get("id") - ] + + results: list[dict[str, Any]] = [] + for item in response.json().get("data", []): + model_id = item.get("id") + if not model_id: + continue + results.append( + { + "model_id": model_id, + "display_name": item.get("display_name") or model_id, + "source": ModelSource.DISCOVERED, + **derive_capabilities(conn, model_id, item), + "metadata": item, + } + ) + return results -def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, bool]: - metadata = metadata or {} - if conn.protocol == ConnectionProtocol.OLLAMA: - caps = metadata.get("capabilities") or [] - capabilities = { - "chat": True, - "vision": "vision" in caps, - "tools": False, - "image_gen": False, - "embedding": "embedding" in caps, - } - return capabilities +async def _ollama_tags_then_show(conn: Connection) -> list[dict[str, Any]]: + if not conn.base_url: + return [] - model_string, _ = to_litellm(conn, model_id) - return _litellm_capabilities(model_string, model_id) + base_url = conn.base_url.rstrip("/") + async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: + response = await client.get(f"{base_url}/api/tags", headers=_auth_headers(conn)) + response.raise_for_status() + models = response.json().get("models", []) + results: list[dict[str, Any]] = [] + for item in models: + model_id = item.get("model") or item.get("name") + if not model_id: + continue + metadata = dict(item) + with contextlib.suppress(Exception): + show_response = await client.post( + f"{base_url}/api/show", + json={"model": model_id}, + headers=_auth_headers(conn), + ) + show_response.raise_for_status() + metadata.update(show_response.json()) + results.append( + { + "model_id": model_id, + "display_name": item.get("name") or model_id, + "source": ModelSource.DISCOVERED, + **derive_capabilities(conn, model_id, metadata), + "metadata": metadata, + } + ) + return results + + +async def _openrouter_models(conn: Connection) -> list[dict[str, Any]]: + base_url = _base_url_or_default(conn) or "https://openrouter.ai/api/v1" + async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: + response = await client.get(f"{ensure_v1(base_url)}/models", headers=_auth_headers(conn)) + response.raise_for_status() + return normalize_openrouter_models(response.json().get("data", [])) + + +def _litellm_static_models(conn: Connection) -> list[dict[str, Any]]: + provider = conn.provider + prefix = spec_for(provider).litellm_prefix or provider + results: list[dict[str, Any]] = [] + for model_string, metadata in litellm.model_cost.items(): + if not isinstance(model_string, str) or not model_string.startswith(f"{prefix}/"): + continue + model_id = model_string.split("/", 1)[1] + results.append( + { + "model_id": model_id, + "display_name": metadata.get("display_name") or model_id, + "source": ModelSource.DISCOVERED, + **_classify_from_litellm(model_string, model_id), + "metadata": metadata, + } + ) + return results async def discover_models(conn: Connection) -> list[dict[str, Any]]: allowlist = _allowlist(conn) + spec = spec_for(conn.provider) - if conn.protocol == ConnectionProtocol.OLLAMA: - url = f"{conn.base_url.rstrip('/')}/api/tags" - async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: - response = await client.get(url, headers=_auth_headers(conn)) - response.raise_for_status() - models = response.json().get("models", []) - results = [ - { - "model_id": item.get("model") or item.get("name"), - "display_name": item.get("name") or item.get("model"), - "source": ModelSource.DISCOVERED, - "capabilities": derive_capabilities(conn, item.get("model") or item.get("name"), item), - "metadata": item, - } - for item in models - if item.get("model") or item.get("name") - ] - elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: - results = await _discover_openai_shaped_models(conn, conn.base_url) - elif conn.protocol == ConnectionProtocol.ANTHROPIC: + if spec.discovery == "ollama": + results = await _ollama_tags_then_show(conn) + elif spec.discovery == "openrouter": + results = await _openrouter_models(conn) + elif spec.discovery == "anthropic_models": results = await _discover_anthropic_models(conn) + elif spec.discovery == "openai_models": + results = await _discover_openai_shaped_models(conn, conn.base_url) + elif spec.discovery == "static": + results = _litellm_static_models(conn) else: results = [] @@ -246,10 +326,7 @@ async def test_model(conn: Connection, model: Model) -> VerifyResult: except Exception as exc: return VerifyResult("UNREACHABLE", False, str(exc)) - model.capabilities_verified = { - **(model.capabilities_verified or {}), - "chat": True, - } + model.supports_chat = True return VerifyResult("OK", True, "Model test succeeded.") diff --git a/surfsense_backend/app/services/model_list_service.py b/surfsense_backend/app/services/model_list_service.py index 1ef0b0c90..0761d7e4f 100644 --- a/surfsense_backend/app/services/model_list_service.py +++ b/surfsense_backend/app/services/model_list_service.py @@ -12,6 +12,8 @@ from pathlib import Path import httpx +from app.services.openrouter_model_normalizer import normalize_openrouter_models + logger = logging.getLogger(__name__) OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models" @@ -121,26 +123,13 @@ def _process_models(raw_models: list[dict]) -> list[dict]: """ processed: list[dict] = [] - for model in raw_models: - model_id: str = model.get("id", "") - name: str = model.get("name", "") - context_length = model.get("context_length") - + for normalized in normalize_openrouter_models(raw_models): + model_id: str = normalized["model_id"] + name: str = normalized.get("display_name") or model_id + context_length = normalized.get("max_input_tokens") if "/" not in model_id: continue - if not _is_text_output_model(model): - continue - - if not _supports_tool_calling(model): - continue - - if not _has_sufficient_context(model): - continue - - if not _is_allowed_model(model): - continue - provider_slug, model_name = model_id.split("/", 1) context_window = _format_context_length(context_length) diff --git a/surfsense_backend/app/services/model_resolver.py b/surfsense_backend/app/services/model_resolver.py index ffa77a9a2..ae6fd2877 100644 --- a/surfsense_backend/app/services/model_resolver.py +++ b/surfsense_backend/app/services/model_resolver.py @@ -12,9 +12,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from app.db import Connection -PROTOCOL_OLLAMA = "OLLAMA" -PROTOCOL_OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" -PROTOCOL_ANTHROPIC = "ANTHROPIC" +from app.services.provider_registry import Transport, spec_for def ensure_v1(base_url: str | None) -> str | None: @@ -32,47 +30,25 @@ def _conn_value(conn: Connection | Mapping[str, Any], key: str) -> Any: return getattr(conn, key) -def _protocol_value(protocol: Any) -> str: - return getattr(protocol, "value", str(protocol)) - - -def default_litellm_provider(protocol: Any) -> str: - protocol_value = _protocol_value(protocol) - defaults = { - PROTOCOL_OLLAMA: "ollama_chat", - PROTOCOL_ANTHROPIC: "anthropic", - PROTOCOL_OPENAI_COMPATIBLE: "openai", - } - return defaults.get(protocol_value, "openai") - - -def _execution_api_base(protocol: str, base_url: str | None) -> str | None: - del protocol - if not base_url: - return None - return base_url.rstrip("/") - - def to_litellm( conn: Connection | Mapping[str, Any], model_id: str, ) -> tuple[str, dict[str, Any]]: """Return ``(model_string, litellm_kwargs)`` for any model role.""" - protocol = _protocol_value(_conn_value(conn, "protocol")) + provider = _conn_value(conn, "provider") base_url = _conn_value(conn, "base_url") api_key = _conn_value(conn, "api_key") - litellm_provider = ( - _conn_value(conn, "litellm_provider") or default_litellm_provider(protocol) - ) extra = _conn_value(conn, "extra") or {} + spec = spec_for(provider) kwargs: dict[str, Any] = {} if api_key: kwargs["api_key"] = api_key - model_string = f"{litellm_provider}/{model_id}" if litellm_provider else model_id - api_base = _execution_api_base(protocol, base_url) - if api_base: + prefix = spec.litellm_prefix or str(provider) + model_string = f"{prefix}/{model_id}" if prefix else model_id + if base_url: + api_base = ensure_v1(base_url) if spec.transport == Transport.OPENAI_COMPATIBLE else base_url.rstrip("/") kwargs["api_base"] = api_base if api_version := extra.get("api_version"): @@ -84,11 +60,11 @@ def to_litellm( def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: """Build an in-memory connection mapping from a global config.""" - protocol = str(config.get("protocol") or PROTOCOL_OPENAI_COMPATIBLE) - litellm_provider = str( - config.get("litellm_provider") + provider = str( + config.get("provider") + or config.get("litellm_provider") or config.get("custom_provider") - or default_litellm_provider(protocol) + or "openai" ) extra: dict[str, Any] = { "litellm_params": config.get("litellm_params") or {}, @@ -96,8 +72,7 @@ def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: if config.get("api_version"): extra["api_version"] = config.get("api_version") return { - "protocol": protocol, - "litellm_provider": litellm_provider, + "provider": provider, "base_url": config.get("api_base") or None, "api_key": config.get("api_key") or None, "extra": extra, @@ -105,7 +80,6 @@ def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: __all__ = [ - "default_litellm_provider", "ensure_v1", "native_connection_from_config", "to_litellm", diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 6996f0fde..fbb70eb5a 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -29,6 +29,13 @@ from app.services.quality_score import ( aggregate_health, static_score_or, ) +from app.services.openrouter_model_normalizer import ( + is_allowed_model as _shared_is_allowed_model, + is_compatible_provider as _shared_is_compatible_provider, + is_openrouter_image_model, + normalize_openrouter_models, + supports_image_input, +) logger = logging.getLogger(__name__) @@ -292,24 +299,16 @@ def _generate_configs( use_default: bool = settings.get("use_default_system_instructions", True) citations_enabled: bool = settings.get("citations_enabled", True) - text_models = [ - m - for m in raw_models - if _is_text_output_model(m) - and _supports_tool_calling(m) - and _has_sufficient_context(m) - and _is_compatible_provider(m) - and _is_allowed_model(m) - and "/" in m.get("id", "") - ] + text_models = normalize_openrouter_models(raw_models) configs: list[dict] = [] taken: set[int] = set() now_ts = int(time.time()) - for model in text_models: - model_id: str = model["id"] - name: str = model.get("name", model_id) + for normalized in text_models: + model = normalized.get("metadata") or {} + model_id: str = normalized["model_id"] + name: str = normalized.get("display_name") or model_id tier = _openrouter_tier(model) static_q = static_score_or(model, now_ts=now_ts) @@ -323,7 +322,7 @@ def _generate_configs( "seo_enabled": seo_enabled, "seo_slug": None, "quota_reserve_tokens": quota_reserve_tokens, - "litellm_provider": "openrouter", + "provider": "openrouter", "model_name": model_id, "api_key": api_key, "api_base": "https://openrouter.ai/api/v1", @@ -345,7 +344,7 @@ def _generate_configs( # ``stream_new_chat`` as a fail-fast safety net before the # OpenRouter request would otherwise 404 with # ``"No endpoints found that support image input"``. - "supports_image_input": _supports_image_input(model), + "supports_image_input": bool(normalized.get("supports_image_input")), _OPENROUTER_DYNAMIC_MARKER: True, # Auto (Fastest) ranking metadata. ``quality_score`` is initialised # to the static score and gets re-blended with health on the next @@ -403,10 +402,7 @@ def _generate_image_gen_configs( image_models = [ m for m in raw_models - if _is_image_output_model(m) - and _is_compatible_provider(m) - and _is_allowed_model(m) - and "/" in m.get("id", "") + if is_openrouter_image_model(m) ] configs: list[dict] = [] @@ -420,7 +416,7 @@ def _generate_image_gen_configs( "id": _stable_config_id(model_id, id_offset, taken), "name": name, "description": f"{name} via OpenRouter (image generation)", - "litellm_provider": "openrouter", + "provider": "openrouter", "model_name": model_id, "api_key": api_key, "api_base": "https://openrouter.ai/api/v1", @@ -468,9 +464,9 @@ def _generate_vision_llm_configs( vision_models = [ m for m in raw_models - if _is_vision_input_model(m) - and _is_compatible_provider(m) - and _is_allowed_model(m) + if supports_image_input(m) + and _shared_is_compatible_provider(m) + and _shared_is_allowed_model(m) and "/" in m.get("id", "") ] @@ -499,7 +495,7 @@ def _generate_vision_llm_configs( "id": _stable_config_id(model_id, id_offset, taken), "name": name, "description": f"{name} via OpenRouter (vision)", - "litellm_provider": "openrouter", + "provider": "openrouter", "model_name": model_id, "api_key": api_key, "api_base": "https://openrouter.ai/api/v1", @@ -544,11 +540,9 @@ class OpenRouterIntegrationService: # Cached raw catalogue from the most recent fetch. Image / vision # emitters reuse this to avoid a second network call per surface. self._raw_models: list[dict] = [] - # Image / vision config caches (only populated when the matching - # opt-in flag is true on initialize). Refreshed in lockstep with - # the chat catalogue. + # Image config cache (only populated when the matching opt-in flag is + # true on initialize). Refreshed in lockstep with the chat catalogue. self._image_configs: list[dict] = [] - self._vision_configs: list[dict] = [] @classmethod def get_instance(cls) -> "OpenRouterIntegrationService": @@ -583,7 +577,7 @@ class OpenRouterIntegrationService: self._configs_by_id = {c["id"]: c for c in self._configs} self._raw_pricing = _extract_raw_pricing(raw_models) - # Populate image / vision caches when their opt-in flag is set. + # Populate image cache when its opt-in flag is set. # Empty otherwise so the accessors return [] without re-running # filters every refresh. if settings.get("image_generation_enabled"): @@ -595,15 +589,6 @@ class OpenRouterIntegrationService: else: self._image_configs = [] - if settings.get("vision_enabled"): - self._vision_configs = _generate_vision_llm_configs(raw_models, settings) - logger.info( - "OpenRouter integration: vision LLM emission ON (%d models)", - len(self._vision_configs), - ) - else: - self._vision_configs = [] - self._initialized = True tier_counts = self._tier_counts(self._configs) @@ -657,9 +642,9 @@ class OpenRouterIntegrationService: self._configs = new_configs self._configs_by_id = new_by_id - # Image / vision lists are atomic-swapped the same way: filter out + # Image list is atomic-swapped the same way: filter out # the previous dynamic entries from the live config list and append - # the freshly generated ones. No-ops when the opt-in flag is off. + # the freshly generated ones. No-op when the opt-in flag is off. if self._settings.get("image_generation_enabled"): new_image = _generate_image_gen_configs(raw_models, self._settings) static_image = [ @@ -670,16 +655,6 @@ class OpenRouterIntegrationService: app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image self._image_configs = new_image - if self._settings.get("vision_enabled"): - new_vision = _generate_vision_llm_configs(raw_models, self._settings) - static_vision = [ - c - for c in app_config.GLOBAL_VISION_LLM_CONFIGS - if not c.get(_OPENROUTER_DYNAMIC_MARKER) - ] - app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision - self._vision_configs = new_vision - # Catalogue churn invalidates per-config "recently healthy" credit # earned by the previous turn's preflight. Drop the whole table so # the next turn re-probes against the freshly loaded configs. @@ -701,7 +676,7 @@ class OpenRouterIntegrationService: ) # Re-blend health scores against the freshly fetched catalogue. Also - # re-stamps health for any YAML-curated cfg with litellm_provider=openrouter + # re-stamps health for any YAML-curated cfg with provider=openrouter # so a hand-picked dead OR model is gated like a dynamic one. await self._enrich_health_safely(static_configs + new_configs, log_summary=True) @@ -778,7 +753,7 @@ class OpenRouterIntegrationService: the entire previous cycle's cache for this run. """ or_cfgs = [ - c for c in configs if str(c.get("litellm_provider", "")).lower() == "openrouter" + c for c in configs if str(c.get("provider", "")).lower() == "openrouter" ] if not or_cfgs: return @@ -959,17 +934,6 @@ class OpenRouterIntegrationService: """ return list(self._image_configs) - def get_vision_llm_configs(self) -> list[dict]: - """Return the dynamic OpenRouter vision-LLM configs (empty list - when the ``vision_enabled`` flag is off). - - Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token`` - so ``pricing_registration`` can teach LiteLLM the cost of these - models the same way it does for chat — which keeps the billable - wrapper able to debit accurate micro-USD on a vision call. - """ - return list(self._vision_configs) - def get_raw_pricing(self) -> dict[str, dict[str, str]]: """Return the cached raw OpenRouter pricing map. diff --git a/surfsense_backend/app/services/openrouter_model_normalizer.py b/surfsense_backend/app/services/openrouter_model_normalizer.py new file mode 100644 index 000000000..0b646f933 --- /dev/null +++ b/surfsense_backend/app/services/openrouter_model_normalizer.py @@ -0,0 +1,121 @@ +"""Shared OpenRouter model normalization. + +OpenRouter metadata is richer than generic OpenAI-compatible ``/models`` +responses. Keep all OpenRouter filtering and capability extraction here so +GLOBAL catalogue generation and BYOK discovery agree. +""" + +from __future__ import annotations + +from typing import Any + +from app.db import ModelSource + +MIN_CONTEXT_LENGTH = 100_000 + +EXCLUDED_PROVIDER_SLUGS = {"amazon"} +EXCLUDED_MODEL_IDS: set[str] = { + "openai/gpt-4-1106-preview", + "openai/gpt-4-turbo-preview", + "openai/gpt-4o:extended", + "arcee-ai/virtuoso-large", + "openai/o3-deep-research", + "openai/o4-mini-deep-research", + "openrouter/free", +} +EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",) + + +def is_text_output_model(model: dict[str, Any]) -> bool: + output_mods = model.get("architecture", {}).get("output_modalities", []) + return output_mods == ["text"] + + +def is_image_output_model(model: dict[str, Any]) -> bool: + output_mods = model.get("architecture", {}).get("output_modalities", []) or [] + return "image" in output_mods + + +def supports_image_input(model: dict[str, Any]) -> bool: + input_mods = model.get("architecture", {}).get("input_modalities", []) or [] + return "image" in input_mods + + +def supports_tool_calling(model: dict[str, Any]) -> bool: + supported = model.get("supported_parameters") or [] + return "tools" in supported + + +def has_sufficient_context(model: dict[str, Any]) -> bool: + return int(model.get("context_length") or 0) >= MIN_CONTEXT_LENGTH + + +def is_compatible_provider(model: dict[str, Any]) -> bool: + model_id = str(model.get("id") or "") + slug = model_id.split("/", 1)[0] if "/" in model_id else "" + return slug not in EXCLUDED_PROVIDER_SLUGS + + +def is_allowed_model(model: dict[str, Any]) -> bool: + model_id = str(model.get("id") or "") + if model_id in EXCLUDED_MODEL_IDS: + return False + base_id = model_id.split(":")[0] + return not base_id.endswith(EXCLUDED_MODEL_SUFFIXES) + + +def is_openrouter_chat_model(model: dict[str, Any]) -> bool: + return ( + "/" in str(model.get("id") or "") + and is_text_output_model(model) + and supports_tool_calling(model) + and has_sufficient_context(model) + and is_compatible_provider(model) + and is_allowed_model(model) + ) + + +def is_openrouter_image_model(model: dict[str, Any]) -> bool: + return ( + "/" in str(model.get("id") or "") + and is_image_output_model(model) + and is_compatible_provider(model) + and is_allowed_model(model) + ) + + +def normalize_openrouter_models(raw_models: list[dict[str, Any]]) -> list[dict[str, Any]]: + normalized: list[dict[str, Any]] = [] + for model in raw_models: + if not is_openrouter_chat_model(model): + continue + model_id = str(model.get("id") or "") + normalized.append( + { + "model_id": model_id, + "display_name": model.get("name") or model_id, + "source": ModelSource.DISCOVERED, + "supports_chat": True, + "max_input_tokens": model.get("context_length"), + "supports_image_input": supports_image_input(model), + "supports_tools": supports_tool_calling(model), + "supports_image_generation": False, + "metadata": model, + } + ) + return normalized + + +__all__ = [ + "MIN_CONTEXT_LENGTH", + "has_sufficient_context", + "is_allowed_model", + "is_compatible_provider", + "is_image_output_model", + "is_openrouter_chat_model", + "is_openrouter_image_model", + "is_text_output_model", + "normalize_openrouter_models", + "supports_image_input", + "supports_tool_calling", +] diff --git a/surfsense_backend/app/services/pricing_registration.py b/surfsense_backend/app/services/pricing_registration.py index 6b99fe723..9e4e3b552 100644 --- a/surfsense_backend/app/services/pricing_registration.py +++ b/surfsense_backend/app/services/pricing_registration.py @@ -143,7 +143,7 @@ def _register_chat_shape_configs( sample_keys: list[str] = [] for cfg in configs: - provider = str(cfg.get("litellm_provider") or "").lower() + provider = str(cfg.get("provider") or cfg.get("litellm_provider") or "").lower() model_name = str(cfg.get("model_name") or "").strip() litellm_params = cfg.get("litellm_params") or {} base_model = str(litellm_params.get("base_model") or model_name).strip() @@ -216,9 +216,8 @@ def _register_chat_shape_configs( def register_pricing_from_global_configs() -> None: """Register pricing for every known LLM deployment with LiteLLM. - Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS`` - so vision calls (during indexing) can resolve cost the same way chat - calls do — namely: + Walks ``config.GLOBAL_LLM_CONFIGS`` so chat and vision calls can resolve + cost from the same chat-shaped deployment configs: 1. ``OPENROUTER``: pulls the cached raw pricing from ``OpenRouterIntegrationService`` (populated during its own @@ -245,10 +244,7 @@ def register_pricing_from_global_configs() -> None: from app.config import config as app_config chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or []) - vision_configs: list[dict] = list( - getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or [] - ) - if not chat_configs and not vision_configs: + if not chat_configs: logger.info("[PricingRegistration] no global configs to register") return @@ -267,7 +263,3 @@ def register_pricing_from_global_configs() -> None: if chat_configs: _register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat") - if vision_configs: - _register_chat_shape_configs( - vision_configs, or_pricing=or_pricing, label="vision" - ) diff --git a/surfsense_backend/app/services/provider_capabilities.py b/surfsense_backend/app/services/provider_capabilities.py index 9521ef7a4..fae283ab6 100644 --- a/surfsense_backend/app/services/provider_capabilities.py +++ b/surfsense_backend/app/services/provider_capabilities.py @@ -51,7 +51,7 @@ logger = logging.getLogger(__name__) def _candidate_model_strings( *, - litellm_provider: str | None, + provider: str | None, model_name: str | None, base_model: str | None, custom_provider: str | None, @@ -78,7 +78,7 @@ def _candidate_model_strings( seen.add(key) candidates.append(key) - provider_prefix = custom_provider or litellm_provider + provider_prefix = custom_provider or provider primary_model = base_model or model_name bare_model = model_name @@ -113,7 +113,7 @@ def _candidate_model_strings( def derive_supports_image_input( *, - litellm_provider: str | None = None, + provider: str | None = None, model_name: str | None = None, base_model: str | None = None, custom_provider: str | None = None, @@ -147,7 +147,7 @@ def derive_supports_image_input( return False for model_string, custom_llm_provider in _candidate_model_strings( - litellm_provider=litellm_provider, + provider=provider, model_name=model_name, base_model=base_model, custom_provider=custom_provider, @@ -172,7 +172,7 @@ def derive_supports_image_input( def is_known_text_only_chat_model( *, - litellm_provider: str | None = None, + provider: str | None = None, model_name: str | None = None, base_model: str | None = None, custom_provider: str | None = None, @@ -193,7 +193,7 @@ def is_known_text_only_chat_model( leads to the regression we're fixing here. """ for model_string, custom_llm_provider in _candidate_model_strings( - litellm_provider=litellm_provider, + provider=provider, model_name=model_name, base_model=base_model, custom_provider=custom_provider, diff --git a/surfsense_backend/app/services/provider_registry.py b/surfsense_backend/app/services/provider_registry.py new file mode 100644 index 000000000..871769f11 --- /dev/null +++ b/surfsense_backend/app/services/provider_registry.py @@ -0,0 +1,98 @@ +"""Provider registry for model connections. + +The provider string is the single public identity axis. This registry only +describes providers whose behavior differs from LiteLLM's native default. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import StrEnum +from typing import Literal + + +class Transport(StrEnum): + NATIVE = "NATIVE" + OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" + OLLAMA = "OLLAMA" + + +DiscoveryKind = Literal[ + "ollama", + "openai_models", + "anthropic_models", + "openrouter", + "static", + "none", +] + +AuthStyle = Literal["bearer", "x-api-key", "none", "native"] + + +@dataclass(frozen=True) +class ProviderSpec: + transport: Transport + litellm_prefix: str | None + discovery: DiscoveryKind + default_base_url: str | None + base_url_required: bool + auth_style: AuthStyle + + +REGISTRY: dict[str, ProviderSpec] = { + "openai": ProviderSpec( + Transport.NATIVE, "openai", "openai_models", None, False, "bearer" + ), + "anthropic": ProviderSpec( + Transport.NATIVE, "anthropic", "anthropic_models", None, False, "x-api-key" + ), + "azure": ProviderSpec(Transport.NATIVE, "azure", "static", None, True, "native"), + "vertex_ai": ProviderSpec( + Transport.NATIVE, "vertex_ai", "static", None, False, "native" + ), + "bedrock": ProviderSpec( + Transport.NATIVE, "bedrock", "static", None, False, "native" + ), + "openrouter": ProviderSpec( + Transport.OPENAI_COMPATIBLE, + "openrouter", + "openrouter", + "https://openrouter.ai/api/v1", + False, + "bearer", + ), + "openai_compatible": ProviderSpec( + Transport.OPENAI_COMPATIBLE, + "openai", + "openai_models", + None, + True, + "bearer", + ), + "lm_studio": ProviderSpec( + Transport.OPENAI_COMPATIBLE, + "openai", + "openai_models", + "http://localhost:1234/v1", + True, + "bearer", + ), + "ollama_chat": ProviderSpec( + Transport.OLLAMA, + "ollama_chat", + "ollama", + "http://localhost:11434", + True, + "none", + ), +} + + +def spec_for(provider: str | None) -> ProviderSpec: + provider_key = (provider or "").strip() + return REGISTRY.get(provider_key) or ProviderSpec( + Transport.NATIVE, provider_key or "openai", "static", None, False, "native" + ) + + +__all__ = ["REGISTRY", "ProviderSpec", "Transport", "spec_for"] diff --git a/surfsense_backend/app/services/quality_score.py b/surfsense_backend/app/services/quality_score.py index 95484439b..9cc9c21ac 100644 --- a/surfsense_backend/app/services/quality_score.py +++ b/surfsense_backend/app/services/quality_score.py @@ -273,7 +273,7 @@ def static_score_yaml(cfg: dict) -> int: listed this model. Pricing / context fall through to lazy ``litellm`` lookups; failures are silent (we just lose those sub-points). """ - provider = str(cfg.get("litellm_provider", "")).lower() + provider = str(cfg.get("provider") or cfg.get("litellm_provider") or "").lower() base = PROVIDER_PRESTIGE_YAML.get(provider, 15) model_name = cfg.get("model_name") or "" diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py index f6fcf75d7..69b9f4ab8 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py @@ -40,7 +40,7 @@ def check_image_input_capability( else None ) if not is_known_text_only_chat_model( - litellm_provider=agent_config.provider, + provider=agent_config.provider, model_name=agent_config.model_name, base_model=agent_base_model, custom_provider=agent_config.custom_provider, diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py index f6870f5fa..cfd50950e 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py @@ -22,6 +22,7 @@ from app.agents.chat.runtime.llm_config import ( ) from app.config import config from app.db import Model, SearchSpace +from app.services.model_capabilities import has_capability from app.services.model_resolver import to_litellm @@ -96,7 +97,7 @@ async def load_llm_bundle( model_id=config_id, search_space=search_space, ) - if not model or not (model.capabilities or {}).get("chat"): + if not model or not has_capability(model, "chat"): return ( None, None, @@ -106,12 +107,12 @@ async def load_llm_bundle( agent_config = _agent_config_from_resolved( config_id=config_id, config_name=model.display_name or model.model_id, - provider=model.connection.litellm_provider or "", + provider=model.connection.provider or "", model_name=model.model_id, api_key=model.connection.api_key, api_base=model.connection.base_url, litellm_params=(model.connection.extra or {}).get("litellm_params"), - supports_image_input=bool((model.capabilities or {}).get("vision")), + supports_image_input=has_capability(model, "vision"), billing_tier="free", ) return ( @@ -121,7 +122,7 @@ async def load_llm_bundle( ) global_model = next((m for m in config.GLOBAL_MODELS if m.get("id") == config_id), None) - if not global_model or not (global_model.get("capabilities") or {}).get("chat"): + if not global_model or not has_capability(global_model, "chat"): return None, None, f"Failed to load global chat model with id {config_id}" global_connection = next( ( @@ -137,12 +138,12 @@ async def load_llm_bundle( agent_config = _agent_config_from_resolved( config_id=config_id, config_name=global_model.get("display_name") or global_model.get("model_id"), - provider=global_connection.get("litellm_provider") or "", + provider=global_connection.get("provider") or "", model_name=global_model["model_id"], api_key=global_connection.get("api_key"), api_base=global_connection.get("base_url"), litellm_params=(global_connection.get("extra") or {}).get("litellm_params"), - supports_image_input=bool((global_model.get("capabilities") or {}).get("vision")), + supports_image_input=has_capability(global_model, "vision"), billing_tier=str(global_model.get("billing_tier", "free")).lower(), ) return ( From 3dd54230e72a26027dafa96038aa762aa30cbb06 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 02:17:37 +0530 Subject: [PATCH 023/212] fix(chat): normalize provider-safe message history --- .../app/tasks/chat/llm_history_normalizer.py | 89 +++++++++++++++++++ .../tasks/chat/message_parts_normalizer.py | 86 ++++++++++++++++++ .../tasks/chat/streaming/agent/event_loop.py | 6 ++ .../flows/shared/assistant_finalize.py | 5 ++ .../chat/streaming/shared/stream_result.py | 4 + surfsense_backend/app/utils/content_utils.py | 28 ++++-- .../chat/runtime/test_llm_config_sanitizer.py | 40 +++++++++ .../tasks/chat/test_llm_history_normalizer.py | 62 +++++++++++++ .../chat/test_message_parts_normalizer.py | 68 ++++++++++++++ 9 files changed, 382 insertions(+), 6 deletions(-) create mode 100644 surfsense_backend/app/tasks/chat/llm_history_normalizer.py create mode 100644 surfsense_backend/app/tasks/chat/message_parts_normalizer.py create mode 100644 surfsense_backend/tests/unit/agents/chat/runtime/test_llm_config_sanitizer.py create mode 100644 surfsense_backend/tests/unit/tasks/chat/test_llm_history_normalizer.py create mode 100644 surfsense_backend/tests/unit/tasks/chat/test_message_parts_normalizer.py diff --git a/surfsense_backend/app/tasks/chat/llm_history_normalizer.py b/surfsense_backend/app/tasks/chat/llm_history_normalizer.py new file mode 100644 index 000000000..e3470d06b --- /dev/null +++ b/surfsense_backend/app/tasks/chat/llm_history_normalizer.py @@ -0,0 +1,89 @@ +"""Convert persisted chat content into provider-safe LangChain history. + +Assistant UI parts are a UI/storage shape, not an LLM prompt shape. This module +extracts only model-safe content before prior turns are replayed to a provider. +""" + +from __future__ import annotations + +from typing import Any + +_USER_CONTENT_TYPES = {"text", "image", "image_url"} + + +def _text_from_block(block: dict[str, Any]) -> str: + value = block.get("text") or block.get("content") or "" + return value if isinstance(value, str) else "" + + +def assistant_content_to_llm_text(content: Any) -> str: + """Return visible assistant text, dropping reasoning/UI/provider blocks.""" + if isinstance(content, str): + return content + if isinstance(content, dict): + return _text_from_block(content) + if not isinstance(content, list): + return "" + + text_chunks: list[str] = [] + for block in content: + if isinstance(block, str): + if block: + text_chunks.append(block) + continue + if not isinstance(block, dict): + continue + if block.get("type") == "text": + text = _text_from_block(block) + if text: + text_chunks.append(text) + return "\n".join(text_chunks) + + +def user_content_to_llm_content( + content: Any, + *, + allow_images: bool = True, +) -> str | list[dict[str, Any]]: + """Return provider-safe user text/image content for LangChain.""" + if isinstance(content, str): + return content + if isinstance(content, dict): + return _text_from_block(content) + if not isinstance(content, list): + return "" + + parts: list[dict[str, Any]] = [] + text_chunks: list[str] = [] + for block in content: + if isinstance(block, str): + if block: + text_chunks.append(block) + continue + if not isinstance(block, dict): + continue + block_type = block.get("type") + if block_type not in _USER_CONTENT_TYPES: + continue + if block_type == "text": + text = _text_from_block(block) + if text: + parts.append({"type": "text", "text": text}) + text_chunks.append(text) + elif allow_images and block_type == "image": + image = block.get("image") + if isinstance(image, str) and image.startswith("data:"): + parts.append({"type": "image_url", "image_url": {"url": image}}) + elif allow_images and block_type == "image_url": + image_url = block.get("image_url") + if isinstance(image_url, dict): + url = image_url.get("url") + if isinstance(url, str) and url.startswith("data:"): + parts.append({"type": "image_url", "image_url": {"url": url}}) + elif isinstance(image_url, str) and image_url.startswith("data:"): + parts.append({"type": "image_url", "image_url": {"url": image_url}}) + + if allow_images and any(part.get("type") == "image_url" for part in parts): + return parts + return "\n".join(text_chunks) + diff --git a/surfsense_backend/app/tasks/chat/message_parts_normalizer.py b/surfsense_backend/app/tasks/chat/message_parts_normalizer.py new file mode 100644 index 000000000..953282e9f --- /dev/null +++ b/surfsense_backend/app/tasks/chat/message_parts_normalizer.py @@ -0,0 +1,86 @@ +"""Normalize final LangChain assistant messages into assistant-ui parts. + +Live streaming remains the primary source for rich, incremental UI state. +This module is only used after the graph has finished so refresh persistence +does not depend on provider-specific streaming chunk shapes. +""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +from langchain_core.messages import AIMessage + + +def _text_from_content(content: Any) -> str: + if isinstance(content, str): + return content + if not isinstance(content, list): + return "" + + text_parts: list[str] = [] + for block in content: + if not isinstance(block, dict): + continue + if block.get("type") != "text": + continue + value = block.get("text") or block.get("content") or "" + if isinstance(value, str) and value: + text_parts.append(value) + return "".join(text_parts) + + +def normalize_ai_message_to_parts(message: AIMessage | Any | None) -> list[dict[str, Any]]: + """Return user-visible assistant-ui parts for a final AI message. + + We intentionally do not backfill provider ``thinking`` / + ``reasoning_content`` blocks here. If reasoning streamed live, the + ``AssistantContentBuilder`` already captured it. If it only exists in the + final model payload, persisting it retroactively could expose content the + UI never showed during the turn. + """ + if message is None: + return [] + + text = _text_from_content(getattr(message, "content", None)).strip() + if not text: + return [] + return [{"type": "text", "text": text}] + + +def last_ai_message(messages: Iterable[Any] | None) -> AIMessage | Any | None: + if messages is None: + return None + for message in reversed(list(messages)): + if isinstance(message, AIMessage): + return message + if getattr(message, "type", None) == "ai": + return message + return None + + +def final_assistant_parts_from_messages(messages: Iterable[Any] | None) -> list[dict[str, Any]]: + return normalize_ai_message_to_parts(last_ai_message(messages)) + + +def has_non_empty_text_part(parts: Iterable[dict[str, Any]]) -> bool: + return any( + part.get("type") == "text" + and isinstance(part.get("text"), str) + and bool(part.get("text", "").strip()) + for part in parts + ) + + +def merge_streamed_and_final_parts( + streamed_parts: list[dict[str, Any]], + final_parts: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Use final-state text only when streaming captured no answer text.""" + if has_non_empty_text_part(streamed_parts): + return streamed_parts + if not has_non_empty_text_part(final_parts): + return streamed_parts + return [*streamed_parts, *final_parts] + diff --git a/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py b/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py index d96144bcd..fb7548818 100644 --- a/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py +++ b/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py @@ -25,6 +25,9 @@ from app.tasks.chat.streaming.graph_stream.event_stream import stream_output from app.tasks.chat.streaming.helpers.interrupt_inspector import ( all_interrupt_values, ) +from app.tasks.chat.message_parts_normalizer import ( + final_assistant_parts_from_messages, +) from app.tasks.chat.streaming.shared.stream_result import StreamResult from app.tasks.chat.streaming.shared.utils import safe_float from app.utils.perf import get_perf_logger @@ -75,6 +78,9 @@ async def stream_agent_events( state = await agent.aget_state(config) state_values = getattr(state, "values", {}) or {} + result.final_message_parts = final_assistant_parts_from_messages( + state_values.get("messages") + ) # Safety net: if astream_events was cancelled before # KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py index be1f102f3..3f767c60b 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py @@ -53,6 +53,7 @@ async def finalize_assistant_message( ): return + from app.tasks.chat.message_parts_normalizer import merge_streamed_and_final_parts from app.tasks.chat.persistence import finalize_assistant_turn builder_stats: dict[str, int] | None = None @@ -74,6 +75,10 @@ async def finalize_assistant_message( "text": stream_result.accumulated_text or "", } ] + content_payload = merge_streamed_and_final_parts( + content_payload, + stream_result.final_message_parts, + ) if builder_stats is not None: _perf_log.info( diff --git a/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py b/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py index a940e8a9f..5e164070a 100644 --- a/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py +++ b/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py @@ -35,3 +35,7 @@ class StreamResult: # (``StreamResult`` is logged in some error branches) from dumping a # potentially-large parts list. content_builder: Any | None = field(default=None, repr=False) + # User-visible assistant message parts derived from the final LangGraph + # state. Used after streaming completes as a provider-agnostic persistence + # backfill when no text chunks reached the live stream. + final_message_parts: list[dict[str, Any]] = field(default_factory=list) diff --git a/surfsense_backend/app/utils/content_utils.py b/surfsense_backend/app/utils/content_utils.py index 05a4610c7..aae936888 100644 --- a/surfsense_backend/app/utils/content_utils.py +++ b/surfsense_backend/app/utils/content_utils.py @@ -18,6 +18,11 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload +from app.tasks.chat.llm_history_normalizer import ( + assistant_content_to_llm_text, + user_content_to_llm_content, +) + if TYPE_CHECKING: from app.db import ChatVisibility @@ -95,17 +100,28 @@ async def bootstrap_history_from_db( langchain_messages: list[HumanMessage | AIMessage] = [] for msg in db_messages: - text_content = extract_text_content(msg.content) - if not text_content: - continue if msg.role == "user": + user_content = user_content_to_llm_content( + msg.content, + allow_images=False, + ) + if not user_content: + continue if is_shared: author_name = ( msg.author.display_name if msg.author else None ) or "A team member" - text_content = f"**[{author_name}]:** {text_content}" - langchain_messages.append(HumanMessage(content=text_content)) + if isinstance(user_content, str): + user_content = f"**[{author_name}]:** {user_content}" + elif user_content and user_content[0].get("type") == "text": + user_content[0] = { + **user_content[0], + "text": f"**[{author_name}]:** {user_content[0].get('text', '')}", + } + langchain_messages.append(HumanMessage(content=user_content)) elif msg.role == "assistant": - langchain_messages.append(AIMessage(content=text_content)) + assistant_text = assistant_content_to_llm_text(msg.content) + if assistant_text: + langchain_messages.append(AIMessage(content=assistant_text)) return langchain_messages diff --git a/surfsense_backend/tests/unit/agents/chat/runtime/test_llm_config_sanitizer.py b/surfsense_backend/tests/unit/agents/chat/runtime/test_llm_config_sanitizer.py new file mode 100644 index 000000000..00d38b4b1 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/chat/runtime/test_llm_config_sanitizer.py @@ -0,0 +1,40 @@ +"""Regression tests for model-boundary message sanitization.""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import AIMessage + +from app.agents.chat.runtime.llm_config import _sanitize_messages + +pytestmark = pytest.mark.unit + + +def test_sanitize_messages_strips_provider_specific_thinking_blocks() -> None: + original = AIMessage( + content=[ + {"type": "thinking", "thinking": "private reasoning"}, + {"type": "text", "text": "visible answer"}, + ] + ) + + sanitized = _sanitize_messages([original]) + + assert sanitized[0].content == "visible answer" + assert original.content == [ + {"type": "thinking", "thinking": "private reasoning"}, + {"type": "text", "text": "visible answer"}, + ] + + +def test_sanitize_messages_sets_tool_only_ai_content_to_none() -> None: + message = AIMessage( + content="", + tool_calls=[{"name": "search", "args": {"q": "x"}, "id": "call_1"}], + ) + + sanitized = _sanitize_messages([message]) + + assert sanitized[0].content is None + assert message.content == "" + diff --git a/surfsense_backend/tests/unit/tasks/chat/test_llm_history_normalizer.py b/surfsense_backend/tests/unit/tasks/chat/test_llm_history_normalizer.py new file mode 100644 index 000000000..14bb030cd --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/test_llm_history_normalizer.py @@ -0,0 +1,62 @@ +"""Unit tests for provider-safe LLM history normalization.""" + +from __future__ import annotations + +import pytest + +from app.tasks.chat.llm_history_normalizer import ( + assistant_content_to_llm_text, + user_content_to_llm_content, +) + +pytestmark = pytest.mark.unit + + +def test_assistant_ui_parts_drop_thinking_steps_for_llm_history() -> None: + content = [ + {"type": "data-thinking-steps", "data": [{"id": "thinking-1"}]}, + {"type": "text", "text": "visible answer"}, + ] + + assert assistant_content_to_llm_text(content) == "visible answer" + + +def test_provider_thinking_blocks_are_not_replayed_to_llm() -> None: + content = [ + {"type": "thinking", "thinking": "private reasoning"}, + {"type": "text", "text": "final answer"}, + ] + + assert assistant_content_to_llm_text(content) == "final answer" + + +def test_unknown_assistant_blocks_are_dropped() -> None: + content = [ + {"type": "redacted_thinking", "data": "hidden"}, + {"type": "tool_use", "name": "search"}, + {"type": "text", "text": "kept"}, + ] + + assert assistant_content_to_llm_text(content) == "kept" + + +def test_user_images_convert_to_openai_compatible_image_url_blocks() -> None: + content = [ + {"type": "text", "text": "look"}, + {"type": "image", "image": "data:image/png;base64,abc"}, + ] + + assert user_content_to_llm_content(content, allow_images=True) == [ + {"type": "text", "text": "look"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + ] + + +def test_user_images_can_be_dropped_for_text_only_history() -> None: + content = [ + {"type": "text", "text": "look"}, + {"type": "image", "image": "data:image/png;base64,abc"}, + ] + + assert user_content_to_llm_content(content, allow_images=False) == "look" + diff --git a/surfsense_backend/tests/unit/tasks/chat/test_message_parts_normalizer.py b/surfsense_backend/tests/unit/tasks/chat/test_message_parts_normalizer.py new file mode 100644 index 000000000..173d03ed5 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/test_message_parts_normalizer.py @@ -0,0 +1,68 @@ +"""Unit tests for final assistant message part normalization.""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage + +from app.tasks.chat.message_parts_normalizer import ( + final_assistant_parts_from_messages, + merge_streamed_and_final_parts, + normalize_ai_message_to_parts, +) + +pytestmark = pytest.mark.unit + + +def test_string_ai_message_content_becomes_text_part() -> None: + assert normalize_ai_message_to_parts(AIMessage(content="hello")) == [ + {"type": "text", "text": "hello"} + ] + + +def test_deepseek_thinking_plus_text_blocks_backfill_only_text() -> None: + message = AIMessage( + content=[ + {"type": "thinking", "thinking": "hidden reasoning"}, + {"type": "text", "text": "Yo bro! What's up?"}, + ], + additional_kwargs={"reasoning_content": "hidden reasoning"}, + ) + + assert normalize_ai_message_to_parts(message) == [ + {"type": "text", "text": "Yo bro! What's up?"} + ] + + +def test_final_parts_use_last_ai_message_and_skip_trailing_tool_messages() -> None: + messages = [ + HumanMessage(content="ask"), + AIMessage(content="draft"), + ToolMessage(content="tool output", tool_call_id="tc-1"), + AIMessage(content=[{"type": "text", "text": "final answer"}]), + ToolMessage(content="trailing tool noise", tool_call_id="tc-2"), + ] + + assert final_assistant_parts_from_messages(messages) == [ + {"type": "text", "text": "final answer"} + ] + + +def test_merge_adds_final_text_when_stream_only_has_thinking_steps() -> None: + streamed = [ + { + "type": "data-thinking-steps", + "data": [{"id": "thinking-1", "status": "completed"}], + } + ] + final = [{"type": "text", "text": "visible answer"}] + + assert merge_streamed_and_final_parts(streamed, final) == [*streamed, *final] + + +def test_merge_does_not_duplicate_when_stream_already_has_text() -> None: + streamed = [{"type": "text", "text": "streamed answer"}] + final = [{"type": "text", "text": "final answer"}] + + assert merge_streamed_and_final_parts(streamed, final) == streamed + From 610ff063d6afab5b573cf5aab8657771f3a7bca6 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 02:17:51 +0530 Subject: [PATCH 024/212] refactor(model-connections): update frontend for provider-based models --- .../[search_space_id]/client-layout.tsx | 2 +- .../[search_space_id]/onboard/page.tsx | 2 +- .../model-connections-query.atoms.ts | 7 + .../components/assistant-ui/thread.tsx | 33 ++-- .../components/new-chat/model-selector.tsx | 13 +- .../settings/model-connections-settings.tsx | 182 ++++++++++-------- .../types/model-connections.types.ts | 44 +++-- .../hooks/use-automation-eligible-models.ts | 11 +- .../lib/apis/model-connections-api.service.ts | 6 + surfsense_web/lib/query-client/cache-keys.ts | 1 + 10 files changed, 177 insertions(+), 124 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx index c7e05fe99..2b16a038a 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx @@ -49,7 +49,7 @@ export function DashboardClientLayout({ const firstGlobalChatModel = useMemo(() => { for (const connection of globalConnections) { - const model = connection.models.find((item) => item.enabled && item.capabilities?.chat); + const model = connection.models.find((item) => item.enabled && item.supports_chat); if (model) return model; } return null; diff --git a/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx index 9cf429a3a..c6fc1c7a2 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx @@ -33,7 +33,7 @@ export default function OnboardPage() { const firstGlobalChatModel = useMemo(() => { for (const connection of globalConnections) { - const model = connection.models.find((item) => item.enabled && item.capabilities?.chat); + const model = connection.models.find((item) => item.enabled && item.supports_chat); if (model) return model; } return null; diff --git a/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts index 617ffe124..87f31ce9b 100644 --- a/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts @@ -11,6 +11,13 @@ export const globalModelConnectionsAtom = atomWithQuery(() => ({ queryFn: () => modelConnectionsApiService.getGlobalConnections(), })); +export const modelProvidersAtom = atomWithQuery(() => ({ + queryKey: cacheKeys.modelConnections.providers(), + enabled: !!getBearerToken(), + staleTime: 60 * 60 * 1000, + queryFn: () => modelConnectionsApiService.getModelProviders(), +})); + export const modelConnectionsAtom = atomWithQuery((get) => { const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); return { diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 5796109f0..722ebb476 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -48,10 +48,10 @@ import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dial import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms"; import { membersAtom } from "@/atoms/members/members-query.atoms"; import { - globalNewLLMConfigsAtom, - llmPreferencesAtom, - newLLMConfigsAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; + globalModelConnectionsAtom, + modelConnectionsAtom, + modelRolesAtom, +} from "@/atoms/model-connections/model-connections-query.atoms"; import { currentUserAtom } from "@/atoms/user/user-query.atoms"; import { AssistantMessage } from "@/components/assistant-ui/assistant-message"; import { ChatSessionStatus } from "@/components/assistant-ui/chat-session-status"; @@ -976,9 +976,9 @@ const ComposerAction: FC = ({ isBlockedByOtherUser = false if (url) setPendingScreenImages((prev) => [...prev, url]); }, [electronAPI, setPendingScreenImages]); - const { data: userConfigs } = useAtomValue(newLLMConfigsAtom); - const { data: globalConfigs } = useAtomValue(globalNewLLMConfigsAtom); - const { data: preferences } = useAtomValue(llmPreferencesAtom); + const { data: globalModelConnections } = useAtomValue(globalModelConnectionsAtom); + const { data: modelConnections } = useAtomValue(modelConnectionsAtom); + const { data: modelRoles } = useAtomValue(modelRolesAtom); const { data: agentTools } = useAtomValue(agentToolsAtom); const disabledTools = useAtomValue(disabledToolsAtom); @@ -1065,15 +1065,18 @@ const ComposerAction: FC = ({ isBlockedByOtherUser = false }, [hydrateDisabled]); const hasModelConfigured = useMemo(() => { - if (!preferences) return false; - const agentLlmId = preferences.agent_llm_id; - if (agentLlmId === null || agentLlmId === undefined) return false; - - if (agentLlmId <= 0) { - return globalConfigs?.some((c) => c.id === agentLlmId) ?? false; + const chatModelId = modelRoles?.chat_model_id ?? 0; + if (chatModelId === 0) { + return [...(globalModelConnections ?? []), ...(modelConnections ?? [])].some((connection) => + connection.models.some((model) => model.enabled && Boolean(model.supports_chat)) + ); } - return userConfigs?.some((c) => c.id === agentLlmId) ?? false; - }, [preferences, globalConfigs, userConfigs]); + return [...(globalModelConnections ?? []), ...(modelConnections ?? [])].some((connection) => + connection.models.some( + (model) => model.id === chatModelId && model.enabled && Boolean(model.supports_chat) + ) + ); + }, [modelRoles?.chat_model_id, globalModelConnections, modelConnections]); const isSendDisabled = isComposerEmpty || !hasModelConfigured || isBlockedByOtherUser; diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 4744da617..6850096d6 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -56,18 +56,18 @@ function modelName(model: ModelRead) { function connectionLabel(connection: ConnectionRead) { if (connection.scope === "GLOBAL") return "Hosted"; - return connection.litellm_provider || connection.protocol; + return connection.provider; } function flattenChatModels(connections: ConnectionRead[]) { return connections.flatMap((connection) => connection.models - .filter((model) => model.enabled && Boolean(model.capabilities?.chat)) + .filter((model) => model.enabled && Boolean(model.supports_chat)) .map((model) => ({ ...model, connectionId: connection.id, connectionLabel: connectionLabel(connection), - provider: connection.litellm_provider || connection.protocol, + provider: connection.provider, })) ); } @@ -184,9 +184,14 @@ export function ModelSelector({ {modelName(model)}
{model.model_id}
+ {model.max_input_tokens ? ( +
+ {model.max_input_tokens.toLocaleString()} context +
+ ) : null}
- {!model.capabilities?.vision ? ( + {!model.supports_image_input ? ( No image diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 29501abda..0e541548b 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -1,11 +1,12 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { CheckCircle2, PlugZap, Plus, RefreshCcw, XCircle } from "lucide-react"; +import { CheckCircle2, PlugZap, Plus, RefreshCcw, Trash2, XCircle } from "lucide-react"; import { useState } from "react"; import { addManualModelMutationAtom, createModelConnectionMutationAtom, + deleteModelConnectionMutationAtom, discoverConnectionModelsMutationAtom, testModelMutationAtom, updateModelConnectionMutationAtom, @@ -16,6 +17,7 @@ import { import { globalModelConnectionsAtom, modelConnectionsAtom, + modelProvidersAtom, modelRolesAtom, } from "@/atoms/model-connections/model-connections-query.atoms"; import { Badge } from "@/components/ui/badge"; @@ -30,37 +32,9 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; -import type { - ConnectionProtocol, - ConnectionRead, - ModelRead, -} from "@/contracts/types/model-connections.types"; +import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; import { getProviderIcon } from "@/lib/provider-icons"; -const PROTOCOL_OPTIONS: { value: ConnectionProtocol; label: string; description: string }[] = [ - { - value: "OPENAI_COMPATIBLE", - label: "OpenAI-compatible", - description: "Use for OpenAI, OpenRouter, Groq, vLLM, LM Studio, and compatible APIs.", - }, - { - value: "ANTHROPIC", - label: "Anthropic", - description: "Use for Claude endpoints that require Anthropic headers.", - }, - { - value: "OLLAMA", - label: "Ollama", - description: "Use for Ollama's native API.", - }, -]; - -function defaultLitellmProvider(protocol: ConnectionProtocol) { - if (protocol === "OLLAMA") return "ollama_chat"; - if (protocol === "ANTHROPIC") return "anthropic"; - return "openai"; -} - // Free-text URL hints (datalist), mirroring OpenWebUI. These never restrict // what the user can type — any OpenAI-compatible endpoint works. const URL_SUGGESTIONS = [ @@ -82,9 +56,19 @@ function modelLabel(model: ModelRead) { } function capability(model: ModelRead, key: "chat" | "vision" | "image_gen") { - return Boolean(model.capabilities?.[key]); + if (key === "chat") return Boolean(model.supports_chat); + if (key === "vision") return Boolean(model.supports_image_input); + return Boolean(model.supports_image_generation); } +type ModelCapabilityFilter = "chat" | "vision" | "image_gen"; + +const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: string }[] = [ + { key: "chat", label: "Chat" }, + { key: "vision", label: "Vision" }, + { key: "image_gen", label: "Image" }, +]; + function StatusBadge({ connection }: { connection: ConnectionRead }) { if (connection.last_status === "OK") { return ( @@ -107,9 +91,9 @@ function flattenModels(connections: ConnectionRead[]) { return connections.flatMap((connection) => connection.models.map((model) => ({ ...model, - connectionName: connection.litellm_provider || connection.protocol, + connectionName: connection.provider, connectionId: connection.id, - provider: connection.litellm_provider || connection.protocol, + provider: connection.provider, })) ); } @@ -118,6 +102,7 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom); const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); const updateConnection = useAtomValue(updateModelConnectionMutationAtom); + const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom); const addManualModel = useAtomValue(addManualModelMutationAtom); const updateModel = useAtomValue(updateModelMutationAtom); const testModel = useAtomValue(testModelMutationAtom); @@ -127,9 +112,16 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { : []; const [allowlistText, setAllowlistText] = useState(allowlist.join(", ")); const [manualModelId, setManualModelId] = useState(""); + const [modelFilter, setModelFilter] = useState(null); - const providerLabel = connection.litellm_provider || connection.protocol; - const isLocal = connection.protocol === "OLLAMA" || !connection.base_url?.startsWith("https"); + const providerLabel = connection.provider; + const isLocal = + connection.provider === "ollama_chat" || + connection.provider === "lm_studio" || + !connection.base_url?.startsWith("https"); + const filteredModels = modelFilter + ? connection.models.filter((model) => capability(model, modelFilter)) + : connection.models; function saveAllowlist() { const ids = allowlistText @@ -151,6 +143,14 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { ); } + function deleteCurrentConnection() { + const confirmed = window.confirm( + `Delete the ${providerLabel} connection and all of its models? This cannot be undone.` + ); + if (!confirmed) return; + deleteConnection.mutate(connection.id); + } + return (
@@ -175,6 +175,14 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { +
@@ -232,8 +240,38 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
+ {connection.models.length > 0 ? ( +
+ Filter models + {MODEL_CAPABILITY_FILTERS.map((filter) => { + const count = connection.models.filter((model) => capability(model, filter.key)).length; + const isActive = modelFilter === filter.key; + + return ( + + ); + })} +
+ ) : null} +
- {connection.models.map((model) => ( + {filteredModels.length === 0 && modelFilter ? ( +
+ No {MODEL_CAPABILITY_FILTERS.find((filter) => filter.key === modelFilter)?.label.toLowerCase()}{" "} + models found on this connection. +
+ ) : null} + {filteredModels.map((model) => (
{["chat", "vision", "image_gen"] - .filter((key) => Boolean(model.capabilities?.[key])) - .join(", ") || "No verified capabilities"} + .filter((key) => capability(model, key as "chat" | "vision" | "image_gen")) + .join(", ") || "No discovered capabilities"}
@@ -278,18 +316,16 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: number }) { const [{ data: globalConnections = [] }] = useAtom(globalModelConnectionsAtom); const [{ data: connections = [] }] = useAtom(modelConnectionsAtom); + const [{ data: providers = [] }] = useAtom(modelProvidersAtom); const [{ data: roles }] = useAtom(modelRolesAtom); const createConnection = useAtomValue(createModelConnectionMutationAtom); const updateRoles = useAtomValue(updateModelRolesMutationAtom); - const [protocol, setProtocol] = useState("OPENAI_COMPATIBLE"); + const [provider, setProvider] = useState("openai_compatible"); const [baseUrl, setBaseUrl] = useState(""); const [apiKey, setApiKey] = useState(""); - const [litellmProvider, setLitellmProvider] = useState(""); - const [showAdvancedProvider, setShowAdvancedProvider] = useState(false); - const selectedProtocol = PROTOCOL_OPTIONS.find((item) => item.value === protocol); - const protocolDefaultProvider = defaultLitellmProvider(protocol); - const isOllama = protocol === "OLLAMA"; + const selectedProvider = providers.find((item) => item.provider === provider); + const isOllama = provider === "ollama_chat"; const allConnections = [...globalConnections, ...connections]; const enabledModels = flattenModels(allConnections).filter((model) => model.enabled); @@ -298,11 +334,9 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num const imageModels = enabledModels.filter((model) => capability(model, "image_gen")); function handleCreate() { - const explicitProvider = litellmProvider.trim(); createConnection.mutate( { - protocol, - litellm_provider: explicitProvider ? explicitProvider : null, + provider, base_url: baseUrl || null, api_key: apiKey || null, scope: "SEARCH_SPACE", @@ -337,18 +371,22 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
- + setLitellmProvider(event.target.value)} - placeholder={protocolDefaultProvider} - /> -

- Leave empty to use the protocol default. Set this for more accurate LiteLLM - capabilities/costs, for example openrouter, groq, gemini, or azure. -

-
- ) : null} -
diff --git a/surfsense_web/contracts/types/model-connections.types.ts b/surfsense_web/contracts/types/model-connections.types.ts index 7a37799c4..a34687d74 100644 --- a/surfsense_web/contracts/types/model-connections.types.ts +++ b/surfsense_web/contracts/types/model-connections.types.ts @@ -1,26 +1,19 @@ import { z } from "zod"; -export const connectionProtocolEnum = z.enum(["OLLAMA", "OPENAI_COMPATIBLE", "ANTHROPIC"]); export const connectionScopeEnum = z.enum(["GLOBAL", "SEARCH_SPACE", "USER"]); export const modelSourceEnum = z.enum(["DISCOVERED", "MANUAL"]); -export const modelCapabilities = z.object({ - chat: z.boolean().optional(), - vision: z.boolean().optional(), - image_gen: z.boolean().optional(), - embedding: z.boolean().optional(), - tools: z.boolean().optional(), -}); - export const modelRead = z.object({ id: z.number(), connection_id: z.number(), model_id: z.string(), display_name: z.string().nullable().optional(), source: z.union([modelSourceEnum, z.string()]), - capabilities: z.record(z.string(), z.any()).default({}), - capabilities_declared: z.record(z.string(), z.any()).default({}), - capabilities_verified: z.record(z.string(), z.any()).default({}), + supports_chat: z.boolean().nullable().optional(), + max_input_tokens: z.number().nullable().optional(), + supports_image_input: z.boolean().nullable().optional(), + supports_tools: z.boolean().nullable().optional(), + supports_image_generation: z.boolean().nullable().optional(), capabilities_override: z.record(z.string(), z.any()).default({}), embedding_dimension: z.number().nullable().optional(), enabled: z.boolean(), @@ -31,8 +24,7 @@ export const modelRead = z.object({ export const connectionRead = z.object({ id: z.number(), - protocol: z.union([connectionProtocolEnum, z.string()]), - litellm_provider: z.string().nullable().optional(), + provider: z.string(), base_url: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).default({}), scope: z.union([connectionScopeEnum, z.string()]), @@ -48,8 +40,7 @@ export const connectionRead = z.object({ }); export const connectionCreateRequest = z.object({ - protocol: connectionProtocolEnum, - litellm_provider: z.string().nullable().optional(), + provider: z.string().min(1), base_url: z.string().nullable().optional(), api_key: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).default({}), @@ -59,7 +50,7 @@ export const connectionCreateRequest = z.object({ }); export const connectionUpdateRequest = z.object({ - litellm_provider: z.string().nullable().optional(), + provider: z.string().nullable().optional(), base_url: z.string().nullable().optional(), api_key: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).optional(), @@ -74,6 +65,11 @@ export const modelCreateRequest = z.object({ export const modelUpdateRequest = z.object({ display_name: z.string().nullable().optional(), enabled: z.boolean().optional(), + supports_chat: z.boolean().nullable().optional(), + max_input_tokens: z.number().nullable().optional(), + supports_image_input: z.boolean().nullable().optional(), + supports_tools: z.boolean().nullable().optional(), + supports_image_generation: z.boolean().nullable().optional(), capabilities_override: z.record(z.string(), z.any()).optional(), }); @@ -89,10 +85,21 @@ export const modelRoles = z.object({ image_gen_model_id: z.number().nullable().optional(), }); +export const modelProviderRead = z.object({ + provider: z.string(), + transport: z.string(), + discovery: z.string(), + default_base_url: z.string().nullable().optional(), + base_url_required: z.boolean(), + auth_style: z.string(), + local_only: z.boolean().default(false), +}); + +export const modelProviderListResponse = z.array(modelProviderRead); + export const connectionListResponse = z.array(connectionRead); export const modelListResponse = z.array(modelRead); -export type ConnectionProtocol = z.infer; export type ConnectionScope = z.infer; export type ModelRead = z.infer; export type ConnectionRead = z.infer; @@ -102,3 +109,4 @@ export type ModelCreateRequest = z.infer; export type ModelUpdateRequest = z.infer; export type ModelRoles = z.infer; export type VerifyConnectionResponse = z.infer; +export type ModelProviderRead = z.infer; diff --git a/surfsense_web/hooks/use-automation-eligible-models.ts b/surfsense_web/hooks/use-automation-eligible-models.ts index f8b264162..fd3ad3a6a 100644 --- a/surfsense_web/hooks/use-automation-eligible-models.ts +++ b/surfsense_web/hooks/use-automation-eligible-models.ts @@ -47,11 +47,16 @@ function buildKind( capability: "chat" | "image_gen" | "vision", prefId: number | null | undefined ): EligibleModelKind { + const supportsCapability = (model: ModelRead) => { + if (capability === "chat") return Boolean(model.supports_chat); + if (capability === "vision") return Boolean(model.supports_image_input); + return Boolean(model.supports_image_generation); + }; const toOption = (connection: ConnectionRead, model: ModelRead, isBYOK: boolean) => ({ id: model.id, name: model.display_name || model.model_id, modelName: model.model_id, - provider: connection.litellm_provider || connection.protocol, + provider: connection.provider, isBYOK, }); @@ -60,7 +65,7 @@ function buildKind( .filter( (model) => model.enabled && - Boolean(model.capabilities?.[capability]) && + supportsCapability(model) && String(model.billing_tier ?? "").toLowerCase() === "premium" ) .map((model) => toOption(connection, model, false)) @@ -68,7 +73,7 @@ function buildKind( const byokOptions: EligibleModelOption[] = (byok ?? []).flatMap((connection) => connection.models - .filter((model) => model.enabled && Boolean(model.capabilities?.[capability])) + .filter((model) => model.enabled && supportsCapability(model)) .map((model) => toOption(connection, model, true)) ); diff --git a/surfsense_web/lib/apis/model-connections-api.service.ts b/surfsense_web/lib/apis/model-connections-api.service.ts index 92eec8e61..12ad8e0d2 100644 --- a/surfsense_web/lib/apis/model-connections-api.service.ts +++ b/surfsense_web/lib/apis/model-connections-api.service.ts @@ -7,10 +7,12 @@ import { connectionRead, connectionUpdateRequest, type ModelCreateRequest, + type ModelProviderRead, type ModelRead, type ModelRoles, type ModelUpdateRequest, modelCreateRequest, + modelProviderListResponse, modelListResponse, modelRead, modelRoles, @@ -26,6 +28,10 @@ class ModelConnectionsApiService { return baseApiService.get(`/api/v1/global-model-connections`, connectionListResponse); }; + getModelProviders = async (): Promise => { + return baseApiService.get(`/api/v1/model-providers`, modelProviderListResponse); + }; + getConnections = async (searchSpaceId: number): Promise => { return baseApiService.get( `/api/v1/model-connections?search_space_id=${searchSpaceId}`, diff --git a/surfsense_web/lib/query-client/cache-keys.ts b/surfsense_web/lib/query-client/cache-keys.ts index 558a73f95..5a3f0fb84 100644 --- a/surfsense_web/lib/query-client/cache-keys.ts +++ b/surfsense_web/lib/query-client/cache-keys.ts @@ -47,6 +47,7 @@ export const cacheKeys = { modelConnections: { all: (searchSpaceId: number) => ["model-connections", searchSpaceId] as const, global: () => ["model-connections", "global"] as const, + providers: () => ["model-connections", "providers"] as const, roles: (searchSpaceId: number) => ["model-roles", searchSpaceId] as const, }, imageGenConfigs: { From 5da3ab0552b6a03e2fcc0e5e7527b9210bd9424f Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 03:20:09 +0530 Subject: [PATCH 025/212] feat(database): rename add_model_connections alembic migration --- ..._model_connections.py => 160_add_model_connections.py} | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) rename surfsense_backend/alembic/versions/{156_add_model_connections.py => 160_add_model_connections.py} (99%) diff --git a/surfsense_backend/alembic/versions/156_add_model_connections.py b/surfsense_backend/alembic/versions/160_add_model_connections.py similarity index 99% rename from surfsense_backend/alembic/versions/156_add_model_connections.py rename to surfsense_backend/alembic/versions/160_add_model_connections.py index 64614db99..49d6315ca 100644 --- a/surfsense_backend/alembic/versions/156_add_model_connections.py +++ b/surfsense_backend/alembic/versions/160_add_model_connections.py @@ -1,7 +1,7 @@ """add model connections -Revision ID: 156 -Revises: 155 +Revision ID: 160 +Revises: 159 """ from collections.abc import Sequence @@ -11,8 +11,8 @@ from sqlalchemy.dialects import postgresql from alembic import op -revision: str = "156" -down_revision: str | None = "155" +revision: str = "160" +down_revision: str | None = "159" branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None From aba95e4faf1ed277eb0b27ba9e4b649e976113f2 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 03:35:49 +0530 Subject: [PATCH 026/212] feat(database): enhance podcast lifecycle management by adding temporary unpublishing during migration --- .../versions/158_evolve_podcasts_lifecycle.py | 28 ++++++++++ surfsense_backend/app/zero_publication.py | 51 ++++++++++++++++--- 2 files changed, 72 insertions(+), 7 deletions(-) diff --git a/surfsense_backend/alembic/versions/158_evolve_podcasts_lifecycle.py b/surfsense_backend/alembic/versions/158_evolve_podcasts_lifecycle.py index 15cf04f9d..7c51158a9 100644 --- a/surfsense_backend/alembic/versions/158_evolve_podcasts_lifecycle.py +++ b/surfsense_backend/alembic/versions/158_evolve_podcasts_lifecycle.py @@ -13,8 +13,34 @@ down_revision: str | None = "157" branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None +PUBLICATION_NAME = "zero_publication" + + +def _drop_podcasts_from_zero_publication() -> None: + """Temporarily unpublish podcasts while changing published columns.""" + + op.execute( + f""" + DO $$ + BEGIN + IF EXISTS ( + SELECT 1 + FROM pg_publication_tables + WHERE pubname = '{PUBLICATION_NAME}' + AND schemaname = current_schema() + AND tablename = 'podcasts' + ) THEN + ALTER PUBLICATION "{PUBLICATION_NAME}" DROP TABLE "podcasts"; + END IF; + END + $$; + """ + ) + def upgrade() -> None: + _drop_podcasts_from_zero_publication() + # Retype the status enum by swapping in a fresh type and casting existing # rows. The legacy transient value 'generating' maps onto 'rendering'. op.execute("ALTER TYPE podcast_status RENAME TO podcast_status_old;") @@ -57,6 +83,8 @@ def upgrade() -> None: def downgrade() -> None: + _drop_podcasts_from_zero_publication() + op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS error;") op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS duration_seconds;") op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS storage_key;") diff --git a/surfsense_backend/app/zero_publication.py b/surfsense_backend/app/zero_publication.py index 139286ee6..869559c55 100644 --- a/surfsense_backend/app/zero_publication.py +++ b/surfsense_backend/app/zero_publication.py @@ -100,12 +100,32 @@ def _column_exists(conn: Connection, table: str, column: str) -> bool: ) -def _expected_columns(conn: Connection, table: str) -> list[str] | None: +def _table_exists(conn: Connection, table: str) -> bool: + return ( + conn.execute( + text( + "SELECT 1 FROM information_schema.tables " + "WHERE table_schema = current_schema() " + "AND table_name = :table" + ), + {"table": table}, + ).fetchone() + is not None + ) + + +def _expected_columns( + conn: Connection, table: str, *, include_missing_columns: bool = True +) -> list[str] | None: columns = ZERO_PUBLICATION[table] if columns is None: return None - expected = list(columns) + if include_missing_columns: + expected = list(columns) + else: + expected = [column for column in columns if _column_exists(conn, table, column)] + if table in {"documents", "user", "podcasts"} and _column_exists( conn, table, "_0_version" ): @@ -113,11 +133,20 @@ def _expected_columns(conn: Connection, table: str) -> list[str] | None: return expected -def _format_table_entry(conn: Connection, table: str) -> str: - columns = _expected_columns(conn, table) +def _format_table_entry( + conn: Connection, table: str, *, include_missing_columns: bool = True +) -> str | None: + if not include_missing_columns and not _table_exists(conn, table): + return None + + columns = _expected_columns( + conn, table, include_missing_columns=include_missing_columns + ) table_sql = _quote_identifier(table) if columns is None: return table_sql + if not include_missing_columns and not columns: + return None column_sql = ", ".join(_quote_identifier(column) for column in columns) return f"{table_sql} ({column_sql})" @@ -126,9 +155,17 @@ def _format_table_entry(conn: Connection, table: str) -> str: def build_set_table_sql(conn: Connection) -> str: """Build the canonical plain SET TABLE statement for Zero's event triggers.""" - table_list = ", ".join( - _format_table_entry(conn, table) for table in ZERO_PUBLICATION - ) + table_entries = [ + entry + for table in ZERO_PUBLICATION + if ( + entry := _format_table_entry( + conn, table, include_missing_columns=False + ) + ) + is not None + ] + table_list = ", ".join(table_entries) return f"ALTER PUBLICATION {_quote_identifier(PUBLICATION_NAME)} SET TABLE {table_list}" From 8e8cf96faa629e7f86ef60ed2f12ed61012b2a4f Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 05:03:14 +0530 Subject: [PATCH 027/212] feat(error-handling): implement LLM error adaptation and classification for chat streaming - Introduced LLMErrorCategory and adapt_llm_exception to normalize LLM exceptions. - Updated llm_retryable_message and llm_permanent_message to utilize the new adaptation logic. - Enhanced classify_stream_exception to classify provider errors and return user-friendly messages. - Added tests for error classification and adaptation to ensure robustness. - Updated frontend error handling to display appropriate messages based on new classifications. --- .../app/indexing_pipeline/exceptions.py | 36 +-- .../app/routes/anonymous_chat_routes.py | 11 +- .../app/services/llm_error_adapter.py | 251 ++++++++++++++++++ .../tasks/chat/streaming/errors/classifier.py | 94 ++++++- .../chat/streaming/test_error_classifier.py | 80 ++++++ .../new-chat/[[...chat_id]]/page.tsx | 12 + .../components/free-chat/free-chat-page.tsx | 17 +- .../lib/chat/chat-error-classifier.ts | 66 ++++- surfsense_web/lib/chat/chat-request-errors.ts | 4 + 9 files changed, 533 insertions(+), 38 deletions(-) create mode 100644 surfsense_backend/app/services/llm_error_adapter.py create mode 100644 surfsense_backend/tests/unit/tasks/chat/streaming/test_error_classifier.py diff --git a/surfsense_backend/app/indexing_pipeline/exceptions.py b/surfsense_backend/app/indexing_pipeline/exceptions.py index 666fa4b9f..bf9d9e9fa 100644 --- a/surfsense_backend/app/indexing_pipeline/exceptions.py +++ b/surfsense_backend/app/indexing_pipeline/exceptions.py @@ -14,6 +14,8 @@ from litellm.exceptions import ( ) from sqlalchemy.exc import IntegrityError as IntegrityError +from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception + # Tuples for use directly in except clauses. RETRYABLE_LLM_ERRORS = ( RateLimitError, @@ -97,38 +99,20 @@ def safe_exception_message(exc: Exception) -> str: def llm_retryable_message(exc: Exception) -> str: try: - if isinstance(exc, RateLimitError): - return PipelineMessages.RATE_LIMIT - if isinstance(exc, Timeout): - return PipelineMessages.LLM_TIMEOUT - if isinstance(exc, ServiceUnavailableError): - return PipelineMessages.LLM_UNAVAILABLE - if isinstance(exc, BadGatewayError): - return PipelineMessages.LLM_BAD_GATEWAY - if isinstance(exc, InternalServerError): - return PipelineMessages.LLM_SERVER_ERROR - if isinstance(exc, APIConnectionError): - return PipelineMessages.LLM_CONNECTION - return safe_exception_message(exc) + adapted = adapt_llm_exception(exc) + if adapted.category is LLMErrorCategory.UNKNOWN: + return safe_exception_message(exc) + return adapted.user_message except Exception: return "Something went wrong when calling the LLM." def llm_permanent_message(exc: Exception) -> str: try: - if isinstance(exc, AuthenticationError): - return PipelineMessages.LLM_AUTH - if isinstance(exc, PermissionDeniedError): - return PipelineMessages.LLM_PERMISSION - if isinstance(exc, NotFoundError): - return PipelineMessages.LLM_NOT_FOUND - if isinstance(exc, BadRequestError): - return PipelineMessages.LLM_BAD_REQUEST - if isinstance(exc, UnprocessableEntityError): - return PipelineMessages.LLM_UNPROCESSABLE - if isinstance(exc, APIResponseValidationError): - return PipelineMessages.LLM_RESPONSE - return safe_exception_message(exc) + adapted = adapt_llm_exception(exc) + if adapted.category is LLMErrorCategory.UNKNOWN: + return safe_exception_message(exc) + return adapted.user_message except Exception: return "Something went wrong when calling the LLM." diff --git a/surfsense_backend/app/routes/anonymous_chat_routes.py b/surfsense_backend/app/routes/anonymous_chat_routes.py index 84420e738..aa0e70464 100644 --- a/surfsense_backend/app/routes/anonymous_chat_routes.py +++ b/surfsense_backend/app/routes/anonymous_chat_routes.py @@ -18,6 +18,7 @@ from app.etl_pipeline.file_classifier import ( PLAINTEXT_EXTENSIONS, ) from app.rate_limiter import limiter +from app.tasks.chat.streaming.errors.classifier import classify_stream_exception logger = logging.getLogger(__name__) @@ -474,7 +475,15 @@ async def stream_anonymous_chat( except Exception as e: logger.exception("Anonymous chat stream error") await TokenQuotaService.anon_release(session_key, ip_key, request_id) - yield streaming_service.format_error(f"Error during chat: {e!s}") + _, error_code, _, _, user_message, extra = classify_stream_exception( + e, + flow_label="chat", + ) + yield streaming_service.format_error( + user_message, + error_code=error_code, + extra=extra, + ) yield streaming_service.format_done() finally: await TokenQuotaService.anon_release_stream_slot(client_ip) diff --git a/surfsense_backend/app/services/llm_error_adapter.py b/surfsense_backend/app/services/llm_error_adapter.py new file mode 100644 index 000000000..b0de15fb0 --- /dev/null +++ b/surfsense_backend/app/services/llm_error_adapter.py @@ -0,0 +1,251 @@ +"""Normalize provider/LLM exceptions into low-cardinality product categories.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from enum import StrEnum +from typing import Any + + +class LLMErrorCategory(StrEnum): + RATE_LIMITED = "rate_limited" + TIMEOUT = "timeout" + PROVIDER_UNAVAILABLE = "provider_unavailable" + BAD_GATEWAY = "bad_gateway" + CONNECTION_FAILED = "connection_failed" + AUTH_FAILED = "auth_failed" + PERMISSION_DENIED = "permission_denied" + MODEL_NOT_FOUND = "model_not_found" + BAD_REQUEST = "bad_request" + CONTEXT_LIMIT = "context_limit" + RESPONSE_INVALID = "response_invalid" + SERVER_ERROR = "server_error" + UNKNOWN = "unknown" + + +@dataclass(frozen=True) +class LLMErrorAdaptation: + category: LLMErrorCategory + retryable: bool + user_message: str + provider_status_code: int | None = None + provider_error_type: str | None = None + + +_CATEGORY_MESSAGES: dict[LLMErrorCategory, str] = { + LLMErrorCategory.RATE_LIMITED: "LLM rate limit exceeded. Will retry on next sync.", + LLMErrorCategory.TIMEOUT: "LLM request timed out. Will retry on next sync.", + LLMErrorCategory.PROVIDER_UNAVAILABLE: "LLM service temporarily unavailable. Will retry on next sync.", + LLMErrorCategory.BAD_GATEWAY: "LLM gateway error. Will retry on next sync.", + LLMErrorCategory.CONNECTION_FAILED: "Could not reach the LLM service. Check network connectivity.", + LLMErrorCategory.AUTH_FAILED: "LLM authentication failed. Check your API key.", + LLMErrorCategory.PERMISSION_DENIED: "LLM request denied. Check your account permissions.", + LLMErrorCategory.MODEL_NOT_FOUND: "Model not found. Check your model configuration.", + LLMErrorCategory.BAD_REQUEST: "LLM rejected the request. Document content may be invalid.", + LLMErrorCategory.CONTEXT_LIMIT: "Document exceeds the LLM context window even after optimization.", + LLMErrorCategory.RESPONSE_INVALID: "LLM returned an invalid response.", + LLMErrorCategory.SERVER_ERROR: "LLM internal server error. Will retry on next sync.", + LLMErrorCategory.UNKNOWN: "Something went wrong when calling the LLM.", +} + +_RETRYABLE_CATEGORIES = { + LLMErrorCategory.RATE_LIMITED, + LLMErrorCategory.TIMEOUT, + LLMErrorCategory.PROVIDER_UNAVAILABLE, + LLMErrorCategory.BAD_GATEWAY, + LLMErrorCategory.CONNECTION_FAILED, + LLMErrorCategory.SERVER_ERROR, +} + +_CLASS_NAME_MAP: tuple[tuple[LLMErrorCategory, tuple[str, ...]], ...] = ( + ( + LLMErrorCategory.RATE_LIMITED, + ("RateLimitError", "TooManyRequests", "TooManyRequestsError"), + ), + (LLMErrorCategory.TIMEOUT, ("Timeout", "APITimeoutError", "TimeoutException")), + ( + LLMErrorCategory.PROVIDER_UNAVAILABLE, + ("ServiceUnavailableError", "ServiceUnavailable"), + ), + ( + LLMErrorCategory.BAD_GATEWAY, + ("BadGatewayError", "GatewayTimeoutError"), + ), + ( + LLMErrorCategory.CONNECTION_FAILED, + ("APIConnectionError", "ConnectError", "ConnectTimeout", "ReadTimeout"), + ), + ( + LLMErrorCategory.AUTH_FAILED, + ("AuthenticationError", "InvalidApiKey", "InvalidAPIKey", "InvalidApiKeyError"), + ), + (LLMErrorCategory.PERMISSION_DENIED, ("PermissionDeniedError", "ForbiddenError")), + (LLMErrorCategory.MODEL_NOT_FOUND, ("NotFoundError", "ModelNotFoundError")), + ( + LLMErrorCategory.CONTEXT_LIMIT, + ("ContextWindowExceeded", "ContextOverflow", "ContextLimit"), + ), + ( + LLMErrorCategory.RESPONSE_INVALID, + ("APIResponseValidationError", "ResponseValidationError"), + ), + ( + LLMErrorCategory.BAD_REQUEST, + ("BadRequestError", "InvalidRequestError", "UnprocessableEntityError"), + ), + (LLMErrorCategory.SERVER_ERROR, ("InternalServerError",)), +) + + +def _parse_error_payload(message: str) -> dict[str, Any] | None: + candidates = [message] + first_brace_idx = message.find("{") + if first_brace_idx >= 0: + candidates.append(message[first_brace_idx:]) + + for candidate in candidates: + try: + parsed = json.loads(candidate) + if isinstance(parsed, dict): + return parsed + except Exception: + continue + return None + + +def _class_names(exc: BaseException) -> tuple[str, ...]: + return tuple(cls.__name__ for cls in type(exc).__mro__) + + +def _category_from_class_name(exc: BaseException) -> LLMErrorCategory | None: + names = _class_names(exc) + for category, hints in _CLASS_NAME_MAP: + if any(any(hint in name for hint in hints) for name in names): + return category + return None + + +def _extract_provider_status_code(parsed: dict[str, Any] | None) -> int | None: + if not isinstance(parsed, dict): + return None + candidates: list[Any] = [parsed.get("code"), parsed.get("status")] + nested = parsed.get("error") + if isinstance(nested, dict): + candidates.extend([nested.get("code"), nested.get("status")]) + for value in candidates: + try: + if value is None: + continue + return int(value) + except Exception: + continue + return None + + +def _extract_provider_error_type(parsed: dict[str, Any] | None) -> str | None: + if not isinstance(parsed, dict): + return None + candidates: list[Any] = [parsed.get("type")] + nested = parsed.get("error") + if isinstance(nested, dict): + candidates.append(nested.get("type")) + for value in candidates: + if isinstance(value, str) and value: + return value + return None + + +def _category_from_provider_payload( + status_code: int | None, + provider_error_type: str | None, +) -> LLMErrorCategory | None: + if status_code == 429: + return LLMErrorCategory.RATE_LIMITED + if status_code == 401: + return LLMErrorCategory.AUTH_FAILED + if status_code == 403: + return LLMErrorCategory.PERMISSION_DENIED + if status_code == 404: + return LLMErrorCategory.MODEL_NOT_FOUND + if status_code in (400, 422): + return LLMErrorCategory.BAD_REQUEST + if status_code in (502, 504): + return LLMErrorCategory.BAD_GATEWAY + if status_code == 503: + return LLMErrorCategory.PROVIDER_UNAVAILABLE + if status_code is not None and status_code >= 500: + return LLMErrorCategory.SERVER_ERROR + + normalized_type = (provider_error_type or "").lower() + if normalized_type == "rate_limit_error": + return LLMErrorCategory.RATE_LIMITED + if normalized_type in {"authentication_error", "invalid_api_key", "invalid_api_key_error"}: + return LLMErrorCategory.AUTH_FAILED + if normalized_type in {"permission_denied", "forbidden"}: + return LLMErrorCategory.PERMISSION_DENIED + if normalized_type in {"not_found_error", "model_not_found"}: + return LLMErrorCategory.MODEL_NOT_FOUND + if normalized_type in {"context_length_exceeded", "context_window_exceeded"}: + return LLMErrorCategory.CONTEXT_LIMIT + return None + + +def _category_from_message(raw: str) -> LLMErrorCategory | None: + lowered = raw.lower() + if any(hint in lowered for hint in ("rate limit", "rate-limited", "temporarily rate-limited")): + return LLMErrorCategory.RATE_LIMITED + if any( + hint in lowered + for hint in ( + "invalid api key", + "invalid_api_key", + "authentication", + "unauthorized", + "user not found", + "api key is expired", + "expired api key", + ) + ): + return LLMErrorCategory.AUTH_FAILED + if "forbidden" in lowered or "permission denied" in lowered: + return LLMErrorCategory.PERMISSION_DENIED + if "model not found" in lowered: + return LLMErrorCategory.MODEL_NOT_FOUND + if any( + hint in lowered + for hint in ( + "context length", + "context window", + "maximum context", + "too many tokens", + ) + ): + return LLMErrorCategory.CONTEXT_LIMIT + return None + + +def adapt_llm_exception(exc: BaseException) -> LLMErrorAdaptation: + raw = str(exc) + parsed = _parse_error_payload(raw) + status_code = _extract_provider_status_code(parsed) + provider_error_type = _extract_provider_error_type(parsed) + + category = ( + _category_from_provider_payload(status_code, provider_error_type) + or _category_from_message(raw) + or _category_from_class_name(exc) + or LLMErrorCategory.UNKNOWN + ) + return LLMErrorAdaptation( + category=category, + retryable=category in _RETRYABLE_CATEGORIES, + user_message=_CATEGORY_MESSAGES[category], + provider_status_code=status_code, + provider_error_type=provider_error_type, + ) + + +def llm_error_message(exc: BaseException) -> str: + return adapt_llm_exception(exc).user_message + diff --git a/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py b/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py index 6b37df343..269143af2 100644 --- a/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py +++ b/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py @@ -12,6 +12,7 @@ from app.agents.chat.multi_agent_chat.main_agent.middleware.busy_mutex import ( is_cancel_requested, ) from app.agents.chat.runtime.errors import BusyError +from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception TURN_CANCELLING_INITIAL_DELAY_MS = 200 TURN_CANCELLING_BACKOFF_FACTOR = 2 @@ -102,6 +103,9 @@ def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None: def is_provider_rate_limited(exc: BaseException) -> bool: """Return True if the exception looks like an upstream HTTP 429 / rate limit.""" + if adapt_llm_exception(exc).category is LLMErrorCategory.RATE_LIMITED: + return True + raw = str(exc) lowered = raw.lower() if "ratelimit" in type(exc).__name__.lower(): @@ -131,6 +135,84 @@ def is_provider_rate_limited(exc: BaseException) -> bool: ) +def _provider_error_extra(adapted: Any) -> dict[str, Any] | None: + extra: dict[str, Any] = {"provider_error_category": adapted.category.value} + if adapted.provider_status_code is not None: + extra["provider_status_code"] = adapted.provider_status_code + if adapted.provider_error_type: + extra["provider_error_type"] = adapted.provider_error_type + return extra + + +def _classify_provider_exception( + exc: Exception, +) -> tuple[ + str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None +] | None: + adapted = adapt_llm_exception(exc) + + if adapted.category is LLMErrorCategory.RATE_LIMITED: + return ( + "rate_limited", + "RATE_LIMITED", + "warn", + True, + "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", + _provider_error_extra(adapted), + ) + + if adapted.category in { + LLMErrorCategory.AUTH_FAILED, + LLMErrorCategory.PERMISSION_DENIED, + }: + return ( + "model_auth_failed", + "MODEL_AUTH_FAILED", + "warn", + True, + "This model's API key is invalid or expired. Switch models, or update the API key.", + _provider_error_extra(adapted), + ) + + if adapted.category is LLMErrorCategory.MODEL_NOT_FOUND: + return ( + "model_not_found", + "MODEL_NOT_FOUND", + "warn", + True, + "The selected model is unavailable or no longer exists. Switch to another model and try again.", + _provider_error_extra(adapted), + ) + + if adapted.category is LLMErrorCategory.CONTEXT_LIMIT: + return ( + "model_context_limit", + "MODEL_CONTEXT_LIMIT", + "warn", + True, + "This request is too large for the selected model. Try a model with a larger context window or reduce the input.", + _provider_error_extra(adapted), + ) + + if adapted.category in { + LLMErrorCategory.TIMEOUT, + LLMErrorCategory.PROVIDER_UNAVAILABLE, + LLMErrorCategory.BAD_GATEWAY, + LLMErrorCategory.CONNECTION_FAILED, + LLMErrorCategory.SERVER_ERROR, + }: + return ( + "model_provider_unavailable", + "MODEL_PROVIDER_UNAVAILABLE", + "warn", + True, + "The selected model provider is temporarily unavailable. Please try again or switch models.", + _provider_error_extra(adapted), + ) + + return None + + def classify_stream_exception( exc: Exception, *, @@ -167,15 +249,9 @@ def classify_stream_exception( None, ) - if is_provider_rate_limited(exc): - return ( - "rate_limited", - "RATE_LIMITED", - "warn", - True, - "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", - None, - ) + provider_classification = _classify_provider_exception(exc) + if provider_classification is not None: + return provider_classification return ( "server_error", diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_error_classifier.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_error_classifier.py new file mode 100644 index 000000000..48b07596c --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_error_classifier.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import pytest + +from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception +from app.tasks.chat.streaming.errors.classifier import classify_stream_exception + +pytestmark = pytest.mark.unit + + +def _exception_named(name: str, message: str) -> Exception: + return type(name, (Exception,), {})(message) + + +def test_adapter_classifies_authentication_error_by_class_name() -> None: + exc = _exception_named("AuthenticationError", "provider rejected credentials") + + adapted = adapt_llm_exception(exc) + + assert adapted.category is LLMErrorCategory.AUTH_FAILED + assert adapted.retryable is False + assert adapted.user_message == "LLM authentication failed. Check your API key." + + +def test_adapter_classifies_embedded_provider_401_payload() -> None: + exc = RuntimeError( + 'litellm.AuthenticationError: OpenrouterException - {"error":{"message":"User not found.","code":401}}' + ) + + adapted = adapt_llm_exception(exc) + + assert adapted.category is LLMErrorCategory.AUTH_FAILED + assert adapted.provider_status_code == 401 + + +def test_adapter_preserves_rate_limit_classification() -> None: + exc = RuntimeError('{"error":{"message":"Slow down","code":429}}') + + adapted = adapt_llm_exception(exc) + + assert adapted.category is LLMErrorCategory.RATE_LIMITED + assert adapted.retryable is True + + +def test_stream_classifier_maps_model_auth_to_stable_code() -> None: + exc = RuntimeError( + 'litellm.AuthenticationError: OpenrouterException - {"error":{"message":"User not found.","code":401}}' + ) + + kind, code, severity, expected, message, extra = classify_stream_exception( + exc, + flow_label="chat", + ) + + assert kind == "model_auth_failed" + assert code == "MODEL_AUTH_FAILED" + assert severity == "warn" + assert expected is True + assert "API key" in message + assert extra == { + "provider_error_category": "auth_failed", + "provider_status_code": 401, + } + + +def test_stream_classifier_keeps_unknown_errors_generic() -> None: + exc = RuntimeError("database exploded") + + kind, code, severity, expected, message, extra = classify_stream_exception( + exc, + flow_label="chat", + ) + + assert kind == "server_error" + assert code == "SERVER_ERROR" + assert severity == "error" + assert expected is False + assert message == "Error during chat: database exploded" + assert extra is None + diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index f048376cc..0c4fa63ec 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -613,6 +613,18 @@ export default function NewChatPage() { return; } + if (normalized.channel === "inline") { + if (normalized.assistantMessage) { + await persistAssistantErrorMessage({ + threadId, + assistantMsgId, + text: normalized.assistantMessage, + }); + } + toast.error(normalized.userMessage); + return; + } + toast.error(normalized.userMessage); }, [currentUser?.id, persistAssistantErrorMessage, searchSpaceId, setPremiumAlertForThread] diff --git a/surfsense_web/components/free-chat/free-chat-page.tsx b/surfsense_web/components/free-chat/free-chat-page.tsx index b28b1e0a1..8d5215fca 100644 --- a/surfsense_web/components/free-chat/free-chat-page.tsx +++ b/surfsense_web/components/free-chat/free-chat-page.tsx @@ -63,6 +63,21 @@ function normalizeFreeChatErrorMessage(error: unknown): string { if (code === "THREAD_BUSY") { return "A previous response is still stopping. Please try again in a moment."; } + if (code === "MODEL_AUTH_FAILED") { + return "This model’s API key is invalid or expired. Switch models, or update the API key."; + } + if (code === "MODEL_NOT_FOUND") { + return "This model is unavailable or no longer exists. Please switch models."; + } + if (code === "MODEL_CONTEXT_LIMIT") { + return "This request is too large for the selected model. Reduce the input or switch models."; + } + if (code === "MODEL_PROVIDER_UNAVAILABLE") { + return "The selected model provider is temporarily unavailable. Please try again or switch models."; + } + if (code === "RATE_LIMITED") { + return "This model is temporarily rate-limited. Please try again in a few seconds or switch models."; + } return error.message || "An unexpected error occurred"; } @@ -154,7 +169,7 @@ export function FreeChatPage() { assistantMsgId: string, signal: AbortSignal, turnstileToken: string | null - ): Promise<"captcha" | void> => { + ): Promise<"captcha" | undefined> => { const reqBody: Record = { model_slug: modelSlug, messages: messageHistory, diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts index 1c67d59a1..92924f0f7 100644 --- a/surfsense_web/lib/chat/chat-error-classifier.ts +++ b/surfsense_web/lib/chat/chat-error-classifier.ts @@ -5,6 +5,10 @@ export type ChatErrorKind = | "thread_busy" | "send_failed_pre_accept" | "auth_expired" + | "model_auth_failed" + | "model_not_found" + | "model_context_limit" + | "model_provider_unavailable" | "rate_limited" | "network_offline" | "stream_interrupted" @@ -14,7 +18,7 @@ export type ChatErrorKind = | "server_error" | "unknown"; -export type ChatErrorChannel = "pinned_inline" | "toast" | "silent"; +export type ChatErrorChannel = "pinned_inline" | "inline" | "toast" | "silent"; export type ChatTelemetryEvent = "chat_blocked" | "chat_error"; export type ChatErrorSeverity = "info" | "warn" | "error"; @@ -206,6 +210,66 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } + if (errorCode === "MODEL_AUTH_FAILED") { + return { + kind: "model_auth_failed", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: + "This model’s API key is invalid or expired. Switch models, or update the API key.", + rawMessage, + errorCode: errorCode ?? "MODEL_AUTH_FAILED", + details: { flow: input.flow, providerErrorType }, + }; + } + + if (errorCode === "MODEL_NOT_FOUND") { + return { + kind: "model_not_found", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: + "This model is unavailable or no longer exists. Switch to another model and try again.", + rawMessage, + errorCode: errorCode ?? "MODEL_NOT_FOUND", + details: { flow: input.flow, providerErrorType }, + }; + } + + if (errorCode === "MODEL_CONTEXT_LIMIT") { + return { + kind: "model_context_limit", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: + "This request is too large for the selected model. Reduce the input or switch models.", + rawMessage, + errorCode: errorCode ?? "MODEL_CONTEXT_LIMIT", + details: { flow: input.flow, providerErrorType }, + }; + } + + if (errorCode === "MODEL_PROVIDER_UNAVAILABLE") { + return { + kind: "model_provider_unavailable", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: + "The selected model provider is temporarily unavailable. Please try again or switch models.", + rawMessage, + errorCode: errorCode ?? "MODEL_PROVIDER_UNAVAILABLE", + details: { flow: input.flow, providerErrorType }, + }; + } + if (errorCode === "RATE_LIMITED" || providerTypeNormalized === "rate_limit_error") { return { kind: "rate_limited", diff --git a/surfsense_web/lib/chat/chat-request-errors.ts b/surfsense_web/lib/chat/chat-request-errors.ts index e0dfb3cc4..c86c72d66 100644 --- a/surfsense_web/lib/chat/chat-request-errors.ts +++ b/surfsense_web/lib/chat/chat-request-errors.ts @@ -91,6 +91,10 @@ export function tagPreAcceptSendFailure(error: unknown): unknown { "TURN_CANCELLING", "AUTH_EXPIRED", "UNAUTHORIZED", + "MODEL_AUTH_FAILED", + "MODEL_NOT_FOUND", + "MODEL_CONTEXT_LIMIT", + "MODEL_PROVIDER_UNAVAILABLE", "RATE_LIMITED", "NETWORK_ERROR", "STREAM_PARSE_ERROR", From ad404b2dbc85ad44a25deea5c0b1979785379911 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 05:15:15 +0530 Subject: [PATCH 028/212] refactor(icons): replace Workflow icon with Clock3 across automation components --- .../automations/components/automations-empty-state.tsx | 4 ++-- .../automations/components/automations-table.tsx | 4 ++-- .../components/layout/providers/LayoutDataProvider.tsx | 4 ++-- surfsense_web/components/new-chat/chat-example-prompts.tsx | 4 ++-- .../components/tool-ui/automation/create-automation.tsx | 6 +++--- surfsense_web/contracts/enums/toolIcons.tsx | 4 ++-- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx index b2e7b2532..70d9990f8 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx @@ -1,5 +1,5 @@ "use client"; -import { Workflow } from "lucide-react"; +import { Clock3 } from "lucide-react"; import Link from "next/link"; import { Button } from "@/components/ui/button"; @@ -18,7 +18,7 @@ export function AutomationsEmptyState({ searchSpaceId, canCreate }: AutomationsE return (
- +

No automations yet

diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx index 8314a5179..727636b43 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx @@ -1,5 +1,5 @@ "use client"; -import { CalendarDays, Info, Workflow } from "lucide-react"; +import { CalendarDays, Clock3, Info } from "lucide-react"; import { Table, TableBody, TableHead, TableHeader, TableRow } from "@/components/ui/table"; import type { AutomationSummary } from "@/contracts/types/automation.types"; import { AutomationRow } from "./automation-row"; @@ -31,7 +31,7 @@ export function AutomationsTable({ - + Name diff --git a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx index 549e6e7d7..5c62f6a7d 100644 --- a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx +++ b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx @@ -2,7 +2,7 @@ import { useQuery } from "@tanstack/react-query"; import { useAtom, useAtomValue, useSetAtom } from "jotai"; -import { AlertTriangle, Inbox, LibraryBig, Workflow } from "lucide-react"; +import { AlertTriangle, Clock3, Inbox, LibraryBig } from "lucide-react"; import { useParams, usePathname, useRouter } from "next/navigation"; import { useTranslations } from "next-intl"; import { useTheme } from "next-themes"; @@ -342,7 +342,7 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid { title: "Automations", url: `/dashboard/${searchSpaceId}/automations`, - icon: Workflow, + icon: Clock3, isActive: isAutomationsActive, }, isMobile diff --git a/surfsense_web/components/new-chat/chat-example-prompts.tsx b/surfsense_web/components/new-chat/chat-example-prompts.tsx index 98d95b98b..4fdc32a92 100644 --- a/surfsense_web/components/new-chat/chat-example-prompts.tsx +++ b/surfsense_web/components/new-chat/chat-example-prompts.tsx @@ -1,12 +1,12 @@ "use client"; import { + Clock3, FilePlus2, Search, Settings2, type LucideIcon, WandSparkles, - Workflow, X, } from "lucide-react"; import { memo, useCallback, useState } from "react"; @@ -22,7 +22,7 @@ interface ChatExamplePromptsProps { const CATEGORY_ICONS: Record = { search: Search, create: FilePlus2, - automate: Workflow, + automate: Clock3, tools: Settings2, }; diff --git a/surfsense_web/components/tool-ui/automation/create-automation.tsx b/surfsense_web/components/tool-ui/automation/create-automation.tsx index 2a7d09f53..644ccd822 100644 --- a/surfsense_web/components/tool-ui/automation/create-automation.tsx +++ b/surfsense_web/components/tool-ui/automation/create-automation.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useAtomValue } from "jotai"; -import { AlertCircle, CornerDownLeftIcon, ExternalLink, Pencil, Workflow } from "lucide-react"; +import { AlertCircle, Clock3, CornerDownLeftIcon, ExternalLink, Pencil } from "lucide-react"; import Link from "next/link"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { @@ -211,7 +211,7 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) {

- +

{phase === "rejected" @@ -404,7 +404,7 @@ function SavedCard({ result }: { result: SavedResult }) { return (

- +

Automation saved

{result.name}

diff --git a/surfsense_web/contracts/enums/toolIcons.tsx b/surfsense_web/contracts/enums/toolIcons.tsx index 494c0eaee..496c26577 100644 --- a/surfsense_web/contracts/enums/toolIcons.tsx +++ b/surfsense_web/contracts/enums/toolIcons.tsx @@ -1,6 +1,7 @@ import { Brain, Calendar, + Clock3, FileEdit, FilePlus, FileText, @@ -24,7 +25,6 @@ import { SearchCheck, Send, Trash2, - Workflow, Wrench, } from "lucide-react"; @@ -47,7 +47,7 @@ const TOOL_ICONS: Record = { scrape_webpage: ScanLine, web_search: Globe, // Automations - create_automation: Workflow, + create_automation: Clock3, // Memory update_memory: Brain, // Filesystem (built-in deepagent + middleware) From ced1bb85edc778e710ff133fe5265df089a40a11 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 09:43:56 +0530 Subject: [PATCH 029/212] feat(model-connections): implement bulk model update endpoint and related schema changes --- .../app/routes/model_connections_routes.py | 31 +- surfsense_backend/app/schemas/__init__.py | 1 + .../app/schemas/model_connections.py | 6 + .../model-connections-mutation.atoms.ts | 12 + .../settings/model-connections-settings.tsx | 626 +++++++++++++----- .../types/model-connections.types.ts | 7 + .../lib/apis/model-connections-api.service.ts | 23 +- 7 files changed, 538 insertions(+), 168 deletions(-) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index ecb86711e..730c68565 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -1,7 +1,7 @@ import logging from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import select +from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload @@ -25,6 +25,7 @@ from app.schemas import ( ModelRead, ModelRolesRead, ModelRolesUpdate, + ModelsBulkUpdate, ModelUpdate, VerifyConnectionResponse, ) @@ -62,6 +63,7 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None id=conn.id, provider=conn.provider, base_url=conn.base_url, + api_key=conn.api_key, extra=conn.extra or {}, scope=conn.scope, search_space_id=conn.search_space_id, @@ -351,6 +353,33 @@ async def add_manual_model( return _model_read(model) +@router.patch("/model-connections/{connection_id}/models", response_model=list[ModelRead]) +async def bulk_update_models( + connection_id: int, + data: ModelsBulkUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + conn = await _load_connection(session, connection_id) + await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value) + + model_ids = set(data.model_ids) + await session.execute( + update(Model) + .where(Model.connection_id == connection_id, Model.id.in_(model_ids)) + .values(enabled=data.enabled) + ) + await session.commit() + session.expire_all() + + result = await session.execute( + select(Model) + .where(Model.connection_id == connection_id, Model.id.in_(model_ids)) + .order_by(Model.id) + ) + return [_model_read(model) for model in result.scalars().all()] + + @router.put("/models/{model_id}", response_model=ModelRead) async def update_model( model_id: int, diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 8ac7c5bbb..55e712f12 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -53,6 +53,7 @@ from .model_connections import ( ModelRead, ModelRolesRead, ModelRolesUpdate, + ModelsBulkUpdate, ModelUpdate, VerifyConnectionResponse, ) diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py index c081a193d..0b03c7fab 100644 --- a/surfsense_backend/app/schemas/model_connections.py +++ b/surfsense_backend/app/schemas/model_connections.py @@ -32,6 +32,7 @@ class ConnectionRead(BaseModel): id: int provider: str base_url: str | None = None + api_key: str | None = None extra: dict[str, Any] = Field(default_factory=dict) scope: ConnectionScope | str search_space_id: int | None = None @@ -87,6 +88,11 @@ class ModelUpdate(BaseModel): capabilities_override: dict[str, Any] | None = None +class ModelsBulkUpdate(BaseModel): + model_ids: list[int] = Field(..., min_length=1, max_length=1000) + enabled: bool + + class ModelProviderRead(BaseModel): provider: str transport: str diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts index 101bad1b5..fee3b95ba 100644 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -6,6 +6,7 @@ import type { ModelCreateRequest, ModelRead, ModelRoles, + ModelsBulkUpdateRequest, ModelUpdateRequest, VerifyConnectionResponse, } from "@/contracts/types/model-connections.types"; @@ -127,6 +128,17 @@ export const updateModelMutationAtom = atomWithMutation((get) => { }; }); +export const bulkUpdateModelsMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["models", "bulk-update"], + mutationFn: ({ connectionId, data }: { connectionId: number; data: ModelsBulkUpdateRequest }) => + modelConnectionsApiService.bulkUpdateModels(connectionId, data), + onSuccess: () => invalidateModelConnections(searchSpaceId), + onError: (error: Error) => toast.error(error.message || "Failed to update models"), + }; +}); + export const testModelMutationAtom = atomWithMutation((get) => { const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); return { diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 0e541548b..9112cfe64 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -1,14 +1,24 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { CheckCircle2, PlugZap, Plus, RefreshCcw, Trash2, XCircle } from "lucide-react"; +import { + Check, + CheckCircle2, + ChevronsUpDown, + Eye, + EyeOff, + RefreshCcw, + Settings, + Trash2, + XCircle, +} from "lucide-react"; import { useState } from "react"; import { addManualModelMutationAtom, + bulkUpdateModelsMutationAtom, createModelConnectionMutationAtom, deleteModelConnectionMutationAtom, discoverConnectionModelsMutationAtom, - testModelMutationAtom, updateModelConnectionMutationAtom, updateModelMutationAtom, updateModelRolesMutationAtom, @@ -20,11 +30,41 @@ import { modelProvidersAtom, modelRolesAtom, } from "@/atoms/model-connections/model-connections-query.atoms"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger, +} from "@/components/ui/alert-dialog"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; +import { Checkbox } from "@/components/ui/checkbox"; +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, +} from "@/components/ui/command"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger, +} from "@/components/ui/dialog"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Select, SelectContent, @@ -32,8 +72,14 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; -import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; +import { Separator } from "@/components/ui/separator"; +import type { + ConnectionRead, + ConnectionUpdateRequest, + ModelRead, +} from "@/contracts/types/model-connections.types"; import { getProviderIcon } from "@/lib/provider-icons"; +import { cn } from "@/lib/utils"; // Free-text URL hints (datalist), mirroring OpenWebUI. These never restrict // what the user can type — any OpenAI-compatible endpoint works. @@ -69,6 +115,67 @@ const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: string }[] { key: "image_gen", label: "Image" }, ]; +function UrlSuggestionCombobox({ + value, + onChange, + placeholder, +}: { + value: string; + onChange: (value: string) => void; + placeholder: string; +}) { + const [open, setOpen] = useState(false); + + return ( + + + + + + + + + + Use the custom URL you typed + + + {URL_SUGGESTIONS.map((url) => ( + { + onChange(url); + setOpen(false); + }} + > + + {url} + + ))} + + + + + + ); +} + function StatusBadge({ connection }: { connection: ConnectionRead }) { if (connection.last_status === "OK") { return ( @@ -105,11 +212,15 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom); const addManualModel = useAtomValue(addManualModelMutationAtom); const updateModel = useAtomValue(updateModelMutationAtom); - const testModel = useAtomValue(testModelMutationAtom); + const bulkUpdateModels = useAtomValue(bulkUpdateModelsMutationAtom); const allowlist = Array.isArray(connection.extra?.model_ids) ? (connection.extra.model_ids as string[]) : []; + const [isSettingsOpen, setIsSettingsOpen] = useState(false); + const [baseUrlDraft, setBaseUrlDraft] = useState(connection.base_url ?? ""); + const [apiKeyDraft, setApiKeyDraft] = useState(""); + const [showApiKey, setShowApiKey] = useState(false); const [allowlistText, setAllowlistText] = useState(allowlist.join(", ")); const [manualModelId, setManualModelId] = useState(""); const [modelFilter, setModelFilter] = useState(null); @@ -122,6 +233,38 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { const filteredModels = modelFilter ? connection.models.filter((model) => capability(model, modelFilter)) : connection.models; + const allFilteredModelsEnabled = + filteredModels.length > 0 && filteredModels.every((model) => model.enabled); + const hasConnectionChanges = + baseUrlDraft.trim() !== (connection.base_url ?? "") || + apiKeyDraft.trim() !== (connection.api_key ?? ""); + + function handleSettingsOpenChange(open: boolean) { + setIsSettingsOpen(open); + if (open) { + setBaseUrlDraft(connection.base_url ?? ""); + setApiKeyDraft(connection.api_key ?? ""); + setShowApiKey(false); + setAllowlistText(allowlist.join(", ")); + } + } + + function saveConnectionSettings() { + const data: ConnectionUpdateRequest = { + base_url: baseUrlDraft.trim() || null, + }; + + if (apiKeyDraft.trim() !== (connection.api_key ?? "")) { + data.api_key = apiKeyDraft.trim() || null; + } + + updateConnection.mutate( + { id: connection.id, data }, + { + onSuccess: () => setApiKeyDraft(""), + } + ); + } function saveAllowlist() { const ids = allowlistText @@ -144,170 +287,321 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { } function deleteCurrentConnection() { - const confirmed = window.confirm( - `Delete the ${providerLabel} connection and all of its models? This cannot be undone.` - ); - if (!confirmed) return; deleteConnection.mutate(connection.id); } + function toggleFilteredModels() { + const nextEnabled = !allFilteredModelsEnabled; + const modelIds = filteredModels + .filter((model) => model.enabled !== nextEnabled) + .map((model) => model.id); + + if (modelIds.length === 0) return; + + bulkUpdateModels.mutate({ + connectionId: connection.id, + data: { model_ids: modelIds, enabled: nextEnabled }, + }); + } + return ( -
-
-
-
+
+
+
+
{getProviderIcon(providerLabel, { className: "size-4" })} - {providerLabel} + {providerLabel} + {connection.scope === "GLOBAL" ? ( + + Default + + ) : null}
-
+
{connection.base_url || "Provider default endpoint"}
-
+
- - - -
-
- - {connection.last_status && connection.last_status !== "OK" ? ( -

- {connection.last_error || "Could not list models."} Chat may still work — add model IDs - manually below. -

- ) : null} - - {!isLocal ? ( -
- -
- setAllowlistText(event.target.value)} - placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro" - /> - -
-

- Leave empty to discover all models. Recommended for providers with large catalogs (e.g. - OpenRouter). -

-
- ) : null} - -
- setManualModelId(event.target.value)} - onKeyDown={(event) => { - if (event.key === "Enter") { - event.preventDefault(); - addModel(); - } - }} - placeholder="Add a model ID manually (for providers without /models)" - /> - -
- - {connection.models.length > 0 ? ( -
- Filter models - {MODEL_CAPABILITY_FILTERS.map((filter) => { - const count = connection.models.filter((model) => capability(model, filter.key)).length; - const isActive = modelFilter === filter.key; - - return ( - - ); - })} -
- ) : null} + + + +
+ {getProviderIcon(providerLabel, { className: "size-5" })} +
+ + Configure {providerLabel} + + + Manage credentials and choose which models are available from this provider. + +
+
+
-
- {filteredModels.length === 0 && modelFilter ? ( -
- No {MODEL_CAPABILITY_FILTERS.find((filter) => filter.key === modelFilter)?.label.toLowerCase()}{" "} - models found on this connection. -
- ) : null} - {filteredModels.map((model) => ( -
-
-
- {getProviderIcon(providerLabel, { className: "size-4" })} - {modelLabel(model)} - {model.source === "MANUAL" ? ( - - manual - - ) : null} +
+
+
+ + +

+ Leave empty to use the provider default endpoint. +

+
+ +
+ +
+ setApiKeyDraft(event.target.value)} + placeholder={connection.has_api_key ? "Saved API key" : "Paste an API key"} + type={showApiKey ? "text" : "password"} + className="pr-11" + /> + +
+
+ + {!isLocal ? ( +
+ +
+ setAllowlistText(event.target.value)} + placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro" + /> + +
+

+ Leave empty to discover all models. Recommended for providers with large + catalogs. +

+
+ ) : null} + + + +
+
+
+
Models
+

+ Select models to make available for this provider. +

+
+
+ + +
+
+ +
+ setManualModelId(event.target.value)} + onKeyDown={(event) => { + if (event.key === "Enter") { + event.preventDefault(); + addModel(); + } + }} + placeholder="Add a model ID manually" + /> + +
+ + {connection.models.length > 0 ? ( +
+ + Filter models + + {MODEL_CAPABILITY_FILTERS.map((filter) => { + const count = connection.models.filter((model) => + capability(model, filter.key) + ).length; + const isActive = modelFilter === filter.key; + + return ( + + ); + })} +
+ ) : null} + +
+ {connection.models.length === 0 ? ( +
+ No models yet. Use the refresh button to discover models or add one + manually. +
+ ) : null} + {filteredModels.length === 0 && modelFilter ? ( +
+ No{" "} + {MODEL_CAPABILITY_FILTERS.find( + (filter) => filter.key === modelFilter + )?.label.toLowerCase()}{" "} + models found on this connection. +
+ ) : null} +
+ {filteredModels.map((model) => ( +
+ + updateModel.mutate({ + id: model.id, + data: { enabled: checked === true }, + }) + } + disabled={updateModel.isPending} + /> +
+
+ {modelLabel(model)} + {model.source === "MANUAL" ? ( + + manual + + ) : null} +
+
+ {["chat", "vision", "image_gen"] + .filter((key) => + capability(model, key as "chat" | "vision" | "image_gen") + ) + .join(", ") || "No discovered capabilities"} +
+
+
+ ))} +
+
+
+ + {connection.last_status && connection.last_status !== "OK" ? ( +

+ {connection.last_error || "Could not list models."} Chat may still work; add + model IDs manually if discovery is unavailable. +

+ ) : null} +
-
- {["chat", "vision", "image_gen"] - .filter((key) => capability(model, key as "chat" | "vision" | "image_gen")) - .join(", ") || "No discovered capabilities"} -
-
-
- + + + + + + + + + -
-
- ))} + + + + Delete this provider? + + {providerLabel} and all of + its models will be removed from this search space. This cannot be undone. + + + + Cancel + + Delete + + + + +
); @@ -394,19 +688,13 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
- setBaseUrl(event.target.value)} + onChange={setBaseUrl} placeholder={ isOllama ? "http://host.docker.internal:11434" : "https://api.example.com/v1" } - list="model-conn-url-suggestions" /> - - {URL_SUGGESTIONS.map((url) => ( -
@@ -425,7 +713,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num Boolean(selectedProvider?.base_url_required && !baseUrl.trim()) } > - Add + Add
@@ -439,11 +727,17 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num

-
- {connections.map((connection) => ( - - ))} -
+ {connections.length > 0 ? ( +
+ +

Available Providers

+
+ {connections.map((connection) => ( + + ))} +
+
+ ) : null} diff --git a/surfsense_web/contracts/types/model-connections.types.ts b/surfsense_web/contracts/types/model-connections.types.ts index a34687d74..c75f4c90a 100644 --- a/surfsense_web/contracts/types/model-connections.types.ts +++ b/surfsense_web/contracts/types/model-connections.types.ts @@ -26,6 +26,7 @@ export const connectionRead = z.object({ id: z.number(), provider: z.string(), base_url: z.string().nullable().optional(), + api_key: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).default({}), scope: z.union([connectionScopeEnum, z.string()]), search_space_id: z.number().nullable().optional(), @@ -73,6 +74,11 @@ export const modelUpdateRequest = z.object({ capabilities_override: z.record(z.string(), z.any()).optional(), }); +export const modelsBulkUpdateRequest = z.object({ + model_ids: z.array(z.number()).min(1).max(1000), + enabled: z.boolean(), +}); + export const verifyConnectionResponse = z.object({ status: z.string(), ok: z.boolean(), @@ -107,6 +113,7 @@ export type ConnectionCreateRequest = z.infer; export type ConnectionUpdateRequest = z.infer; export type ModelCreateRequest = z.infer; export type ModelUpdateRequest = z.infer; +export type ModelsBulkUpdateRequest = z.infer; export type ModelRoles = z.infer; export type VerifyConnectionResponse = z.infer; export type ModelProviderRead = z.infer; diff --git a/surfsense_web/lib/apis/model-connections-api.service.ts b/surfsense_web/lib/apis/model-connections-api.service.ts index 12ad8e0d2..bd5aa1309 100644 --- a/surfsense_web/lib/apis/model-connections-api.service.ts +++ b/surfsense_web/lib/apis/model-connections-api.service.ts @@ -10,12 +10,14 @@ import { type ModelProviderRead, type ModelRead, type ModelRoles, + type ModelsBulkUpdateRequest, type ModelUpdateRequest, modelCreateRequest, - modelProviderListResponse, modelListResponse, + modelProviderListResponse, modelRead, modelRoles, + modelsBulkUpdateRequest, modelUpdateRequest, type VerifyConnectionResponse, verifyConnectionResponse, @@ -97,6 +99,25 @@ class ModelConnectionsApiService { }); }; + bulkUpdateModels = async ( + connectionId: number, + request: ModelsBulkUpdateRequest + ): Promise => { + const parsed = modelsBulkUpdateRequest.safeParse(request); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); + } + return baseApiService.request( + `/api/v1/model-connections/${connectionId}/models`, + modelListResponse, + { + method: "PATCH", + headers: { "Content-Type": "application/json" }, + body: parsed.data, + } + ); + }; + testModel = async (id: number): Promise => { return baseApiService.post(`/api/v1/models/${id}/test`, verifyConnectionResponse); }; From 87be162d78446211cc5511dbdc424583f4507357 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 07:38:38 +0200 Subject: [PATCH 030/212] feat(podcast): curated common languages data --- .../app/podcasts/voices/data/languages.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 surfsense_backend/app/podcasts/voices/data/languages.py diff --git a/surfsense_backend/app/podcasts/voices/data/languages.py b/surfsense_backend/app/podcasts/voices/data/languages.py new file mode 100644 index 000000000..c00fd7f05 --- /dev/null +++ b/surfsense_backend/app/podcasts/voices/data/languages.py @@ -0,0 +1,33 @@ +"""Curated languages offered when a roster has wildcard (any-language) voices. + +OpenAI-style multilingual voices speak whatever language the text is in, so +there is no provider list to enumerate. This is the set the brief form offers +up front for such providers; it is an offering, not a limit — the API flags +``allows_custom`` so users can enter any BCP-47 tag beyond it. +""" + +from __future__ import annotations + +COMMON_LANGUAGES: tuple[str, ...] = ( + "ar", + "bn", + "de", + "en", + "es", + "fr", + "hi", + "id", + "it", + "ja", + "ko", + "nl", + "pl", + "pt", + "ru", + "sw", + "th", + "tr", + "uk", + "vi", + "zh", +) From c8ee74b123a0a925f17b19320f7a60b3153b98ae Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 07:38:38 +0200 Subject: [PATCH 031/212] feat(podcast): offerable languages on the voice catalog --- .../app/podcasts/voices/catalog.py | 31 ++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/surfsense_backend/app/podcasts/voices/catalog.py b/surfsense_backend/app/podcasts/voices/catalog.py index 28914e742..101f8f9d0 100644 --- a/surfsense_backend/app/podcasts/voices/catalog.py +++ b/surfsense_backend/app/podcasts/voices/catalog.py @@ -9,11 +9,26 @@ provider-native reference. from __future__ import annotations from collections.abc import Iterable +from dataclasses import dataclass from functools import lru_cache from .data import AZURE_VOICES, KOKORO_VOICES, OPENAI_VOICES, VERTEX_VOICES +from .data.languages import COMMON_LANGUAGES from .provider import TtsProvider -from .voice import CatalogVoice +from .voice import ANY_LANGUAGE, CatalogVoice + + +@dataclass(frozen=True, slots=True) +class LanguageOffering: + """The languages a provider's roster can offer the brief form. + + ``allows_custom`` is true when the roster has wildcard voices: the listed + languages are then a curated starting point, not a limit, and any BCP-47 + tag may be entered. + """ + + languages: list[str] + allows_custom: bool class VoiceCatalog: @@ -46,6 +61,20 @@ class VoiceCatalog: """Whether ``provider`` has at least one voice for ``language``.""" return any(v.speaks(language) for v in self.for_provider(provider)) + def offerable_languages(self, provider: TtsProvider) -> LanguageOffering: + """The languages ``provider`` can offer up front. + + Language-bound voices contribute their concrete tags; wildcard voices + cannot enumerate languages, so their presence merges in the curated + common list and opens free entry. + """ + voices = self.for_provider(provider) + tags = {v.language for v in voices if v.language != ANY_LANGUAGE} + has_wildcard = any(v.language == ANY_LANGUAGE for v in voices) + if has_wildcard: + tags.update(COMMON_LANGUAGES) + return LanguageOffering(languages=sorted(tags), allows_custom=has_wildcard) + @lru_cache(maxsize=1) def get_voice_catalog() -> VoiceCatalog: From f3d253ae7773e7b077734ac6940ff9d0cdcfbf88 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 07:38:38 +0200 Subject: [PATCH 032/212] feat(podcast): export LanguageOffering --- surfsense_backend/app/podcasts/voices/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/surfsense_backend/app/podcasts/voices/__init__.py b/surfsense_backend/app/podcasts/voices/__init__.py index ab1f8bbbf..97874a655 100644 --- a/surfsense_backend/app/podcasts/voices/__init__.py +++ b/surfsense_backend/app/podcasts/voices/__init__.py @@ -6,7 +6,7 @@ configured provider via :func:`provider_from_service`. from __future__ import annotations -from .catalog import VoiceCatalog, get_voice_catalog +from .catalog import LanguageOffering, VoiceCatalog, get_voice_catalog from .preview import render_voice_preview from .provider import TtsProvider, provider_from_service from .voice import ANY_LANGUAGE, CatalogVoice, VoiceGender @@ -14,6 +14,7 @@ from .voice import ANY_LANGUAGE, CatalogVoice, VoiceGender __all__ = [ "ANY_LANGUAGE", "CatalogVoice", + "LanguageOffering", "TtsProvider", "VoiceCatalog", "VoiceGender", From fe4d69f478e407c4b2faf6d4c0851a02af73d5fa Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 07:38:38 +0200 Subject: [PATCH 033/212] feat(podcast): LanguageOptions read model --- surfsense_backend/app/podcasts/api/schemas.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/surfsense_backend/app/podcasts/api/schemas.py b/surfsense_backend/app/podcasts/api/schemas.py index 7f1f8cc7c..c412e372f 100644 --- a/surfsense_backend/app/podcasts/api/schemas.py +++ b/surfsense_backend/app/podcasts/api/schemas.py @@ -51,6 +51,17 @@ class VoiceOption(BaseModel): gender: str +class LanguageOptions(BaseModel): + """The languages the brief editor may offer for the active provider. + + When ``allows_custom`` is true the list is a curated starting point and + the editor accepts any BCP-47 tag beyond it. + """ + + languages: list[str] + allows_custom: bool + + class PodcastSummary(BaseModel): """Lightweight list item.""" From 1ee38fc9ab930d71806afa49d04b77ff601c9ccb Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 07:38:38 +0200 Subject: [PATCH 034/212] feat(podcast): GET /podcasts/languages --- surfsense_backend/app/podcasts/api/routes.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/surfsense_backend/app/podcasts/api/routes.py b/surfsense_backend/app/podcasts/api/routes.py index 0a9a8e659..593ec6990 100644 --- a/surfsense_backend/app/podcasts/api/routes.py +++ b/surfsense_backend/app/podcasts/api/routes.py @@ -45,6 +45,7 @@ from app.utils.rbac import check_permission from .schemas import ( CreatePodcastRequest, + LanguageOptions, PodcastDetail, PodcastSummary, UpdateSpecRequest, @@ -112,6 +113,20 @@ async def list_voices(language: str | None = None): ] +@router.get("/podcasts/languages", response_model=LanguageOptions) +async def list_languages(): + """Languages the active TTS provider can offer the brief editor.""" + if not app_config.TTS_SERVICE: + raise HTTPException(status_code=503, detail="No TTS provider configured") + + provider = provider_from_service(app_config.TTS_SERVICE) + offering = get_voice_catalog().offerable_languages(provider) + return LanguageOptions( + languages=offering.languages, + allows_custom=offering.allows_custom, + ) + + @router.get("/podcasts/voices/{voice_id}/preview") async def preview_voice( voice_id: str, From a19b7dd8e0ca3324e9dee974d650be3856bb96a0 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 07:38:38 +0200 Subject: [PATCH 035/212] test(podcast): offerable languages catalog rules --- .../tests/unit/podcasts/test_voice_catalog.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/surfsense_backend/tests/unit/podcasts/test_voice_catalog.py b/surfsense_backend/tests/unit/podcasts/test_voice_catalog.py index d94c85922..e7d4c8d2b 100644 --- a/surfsense_backend/tests/unit/podcasts/test_voice_catalog.py +++ b/surfsense_backend/tests/unit/podcasts/test_voice_catalog.py @@ -73,6 +73,59 @@ def test_supports_language_reports_availability(): assert not catalog.supports_language(TtsProvider.KOKORO, "de") +def test_offerable_languages_for_a_concrete_roster_are_its_tags_only(): + """A provider whose voices are language-bound offers exactly those tags.""" + catalog = VoiceCatalog( + [ + _voice("k1", language="en-US"), + _voice("k2", language="fr"), + _voice("k3", language="fr"), + ] + ) + + offering = catalog.offerable_languages(TtsProvider.KOKORO) + + assert offering.languages == ["en-US", "fr"] + assert offering.allows_custom is False + + +def test_a_wildcard_roster_offers_the_curated_languages_and_custom_entry(): + """Voices that speak anything can't enumerate languages themselves, so the + catalog offers the curated common list and invites free entry.""" + catalog = VoiceCatalog( + [_voice("o1", provider=TtsProvider.OPENAI, language=ANY_LANGUAGE)] + ) + + offering = catalog.offerable_languages(TtsProvider.OPENAI) + + assert {"en", "fr", "sw", "hi", "zh"} <= set(offering.languages) + assert offering.allows_custom is True + + +def test_a_mixed_roster_offers_the_union_of_concrete_and_curated(): + catalog = VoiceCatalog( + [ + _voice("v1", provider=TtsProvider.VERTEX_AI, language="en-GB"), + _voice("v2", provider=TtsProvider.VERTEX_AI, language=ANY_LANGUAGE), + ] + ) + + offering = catalog.offerable_languages(TtsProvider.VERTEX_AI) + + assert "en-GB" in offering.languages + assert "fr" in offering.languages + assert offering.allows_custom is True + + +def test_a_provider_with_no_voices_offers_nothing(): + catalog = VoiceCatalog([_voice("k1")]) + + offering = catalog.offerable_languages(TtsProvider.OPENAI) + + assert offering.languages == [] + assert offering.allows_custom is False + + def test_get_raises_for_an_unknown_voice(): catalog = VoiceCatalog([_voice("k1")]) with pytest.raises(KeyError): From 402ae6befec50a7cd5c0d323d61f6b2749bf33a1 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 07:38:38 +0200 Subject: [PATCH 036/212] test(podcast): languages endpoint --- .../tests/integration/podcasts/test_voices.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/surfsense_backend/tests/integration/podcasts/test_voices.py b/surfsense_backend/tests/integration/podcasts/test_voices.py index 688ddad56..fd41bfd4e 100644 --- a/surfsense_backend/tests/integration/podcasts/test_voices.py +++ b/surfsense_backend/tests/integration/podcasts/test_voices.py @@ -29,3 +29,23 @@ async def test_voices_503_when_no_tts_configured(client, monkeypatch): resp = await client.get(f"{BASE}/voices") assert resp.status_code == 503 + + +async def test_languages_returns_the_active_providers_offering(client): + """The brief form renders exactly what the backend offers — for a wildcard + provider (openai/tts-1) that is the curated list plus free entry.""" + resp = await client.get(f"{BASE}/languages") + + assert resp.status_code == 200 + offering = resp.json() + assert "en" in offering["languages"] + assert "fr" in offering["languages"] + assert offering["allows_custom"] is True + + +async def test_languages_503_when_no_tts_configured(client, monkeypatch): + monkeypatch.setattr(app_config, "TTS_SERVICE", "") + + resp = await client.get(f"{BASE}/languages") + + assert resp.status_code == 503 From 0c7e5dee8bd98e814ed944638924e504dca92696 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 07:38:38 +0200 Subject: [PATCH 037/212] test(podcast): align quota error kwargs with wallet refactor --- .../tests/integration/podcasts/test_draft_task.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/surfsense_backend/tests/integration/podcasts/test_draft_task.py b/surfsense_backend/tests/integration/podcasts/test_draft_task.py index 7dadfc2f5..e9c9e4a9c 100644 --- a/surfsense_backend/tests/integration/podcasts/test_draft_task.py +++ b/surfsense_backend/tests/integration/podcasts/test_draft_task.py @@ -76,8 +76,7 @@ async def test_quota_denial_fails_the_podcast_without_a_transcript( async def _deny(**_kwargs): raise QuotaInsufficientError( usage_type="podcast_generation", - used_micros=5_000_000, - limit_micros=5_000_000, + balance_micros=0, remaining_micros=0, ) yield # pragma: no cover - unreachable, satisfies the CM protocol From 3cf76e82951de8bdf74d5e877bf96bb409add3cc Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 07:38:38 +0200 Subject: [PATCH 038/212] feat(podcast): languageOptions contract --- surfsense_web/contracts/types/podcast.types.ts | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/surfsense_web/contracts/types/podcast.types.ts b/surfsense_web/contracts/types/podcast.types.ts index e6332d5b2..627cc6f58 100644 --- a/surfsense_web/contracts/types/podcast.types.ts +++ b/surfsense_web/contracts/types/podcast.types.ts @@ -103,6 +103,15 @@ export const voiceOption = z.object({ }); export type VoiceOption = z.infer; +// The languages the backend offers for the active TTS provider. When +// `allows_custom` is true the list is a starting point and any BCP-47 tag +// may be entered. +export const languageOptions = z.object({ + languages: z.array(z.string()), + allows_custom: z.boolean(), +}); +export type LanguageOptions = z.infer; + export const updateSpecRequest = z.object({ spec: podcastSpec, expected_version: z.number().int().min(1), From 90cae46b5f2a129fc12c230e6f095a5c957da198 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 07:38:38 +0200 Subject: [PATCH 039/212] feat(podcast): listLanguages API call --- surfsense_web/lib/apis/podcasts-api.service.ts | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/surfsense_web/lib/apis/podcasts-api.service.ts b/surfsense_web/lib/apis/podcasts-api.service.ts index bd7bb784e..2e13d63cc 100644 --- a/surfsense_web/lib/apis/podcasts-api.service.ts +++ b/surfsense_web/lib/apis/podcasts-api.service.ts @@ -1,5 +1,6 @@ import { z } from "zod"; import { + languageOptions, type PodcastSpec, podcastDetail, updateSpecRequest, @@ -60,6 +61,12 @@ class PodcastsApiService { return baseApiService.get(`${BASE}/voices${qs}`, voiceOptionList); }; + // The languages the active provider can offer; the brief form renders + // exactly this list and only opens free entry when the backend allows it. + listLanguages = async () => { + return baseApiService.get(`${BASE}/languages`, languageOptions); + }; + // A short audio sample of a voice, cached server-side per voice. previewVoice = async (voiceId: string) => { return baseApiService.getBlob(`${BASE}/voices/${encodeURIComponent(voiceId)}/preview`); From 8dd174d304d5125aad95509bee6e036230dc24a1 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 07:38:38 +0200 Subject: [PATCH 040/212] feat(podcast): backend-driven language picker with custom entry --- .../tool-ui/podcast/brief-review.tsx | 132 +++++++++++++++--- 1 file changed, 114 insertions(+), 18 deletions(-) diff --git a/surfsense_web/components/tool-ui/podcast/brief-review.tsx b/surfsense_web/components/tool-ui/podcast/brief-review.tsx index 3473b64d6..d662aebc2 100644 --- a/surfsense_web/components/tool-ui/podcast/brief-review.tsx +++ b/surfsense_web/components/tool-ui/podcast/brief-review.tsx @@ -1,11 +1,20 @@ "use client"; -import { Loader2, Plus, Trash2 } from "lucide-react"; +import { Check, ChevronDown, Loader2, Plus, Trash2 } from "lucide-react"; import { useEffect, useMemo, useState } from "react"; import { toast } from "sonner"; import { Button } from "@/components/ui/button"; +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, +} from "@/components/ui/command"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Select, SelectContent, @@ -15,6 +24,7 @@ import { } from "@/components/ui/select"; import { Textarea } from "@/components/ui/textarea"; import { + type LanguageOptions, MAX_SPEAKERS, type PodcastSpec, type PodcastStyle, @@ -56,6 +66,7 @@ interface BriefReviewProps { export function BriefReview({ podcast, spec }: BriefReviewProps) { const [draft, setDraft] = useState(spec); const [voices, setVoices] = useState(null); + const [offering, setOffering] = useState(null); const [isSubmitting, setIsSubmitting] = useState(false); // A pushed spec change (saved edit or concurrent editor) resets the form to @@ -75,19 +86,26 @@ export function BriefReview({ podcast, spec }: BriefReviewProps) { .catch(() => { if (!cancelled) setVoices([]); }); + podcastsApiService + .listLanguages() + .then((options) => { + if (!cancelled) setOffering(options); + }) + .catch(() => { + if (!cancelled) setOffering({ languages: [], allows_custom: false }); + }); return () => { cancelled = true; }; }, []); + // The backend owns the offering; the draft's language stays listed even + // when it falls outside it (e.g. a custom tag entered earlier). const languages = useMemo(() => { - const tags = new Set(); - for (const voice of voices ?? []) { - if (voice.language !== ANY_LANGUAGE) tags.add(voice.language); - } + const tags = new Set(offering?.languages ?? []); tags.add(draft.language); return [...tags].sort(); - }, [voices, draft.language]); + }, [offering, draft.language]); const voicesForLanguage = useMemo( () => (voices ?? []).filter((voice) => speaks(voice, draft.language)), @@ -193,18 +211,22 @@ export function BriefReview({ podcast, spec }: BriefReviewProps) {
- + {offering?.allows_custom ? ( + + ) : ( + + )}
@@ -375,6 +397,80 @@ export function BriefReview({ podcast, spec }: BriefReviewProps) { ); } +/** A searchable language picker for providers whose voices speak anything: + * the offered list comes from the backend, and any BCP-47 tag may be typed + * when none of them fits. */ +function LanguageCombobox({ + value, + languages, + onSelect, +}: { + value: string; + languages: string[]; + onSelect: (language: string) => void; +}) { + const [open, setOpen] = useState(false); + const [query, setQuery] = useState(""); + + const pick = (tag: string) => { + onSelect(tag); + setOpen(false); + setQuery(""); + }; + + const customTag = query.trim(); + const isNewTag = + customTag.length > 0 && !languages.some((tag) => tag.toLowerCase() === customTag.toLowerCase()); + + return ( + + + + + + + + + No matching language. + + {languages.map((tag) => ( + pick(tag)} + > + + {languageLabel(tag)} + + ))} + {isNewTag ? ( + pick(customTag)}> + + Use “{customTag}” + + ) : null} + + + + + + ); +} + /** The current selection stays listed even when it no longer matches the * language filter, so the Select never renders an orphaned value. */ function voiceItems(candidates: VoiceOption[], selectedId: string): VoiceOption[] { From 24f824b597cb73eaddb5ccb189d62b518a5c6c62 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:22:48 +0200 Subject: [PATCH 041/212] feat(etl-cache): add ParseKey cache identity value object --- .../etl_pipeline/cache/schemas/parse_key.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 surfsense_backend/app/etl_pipeline/cache/schemas/parse_key.py diff --git a/surfsense_backend/app/etl_pipeline/cache/schemas/parse_key.py b/surfsense_backend/app/etl_pipeline/cache/schemas/parse_key.py new file mode 100644 index 000000000..65e7b08a5 --- /dev/null +++ b/surfsense_backend/app/etl_pipeline/cache/schemas/parse_key.py @@ -0,0 +1,28 @@ +"""Identity of a cacheable parse: equal keys yield identical markdown.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class ParseKey: + source_sha256: str + etl_service: str + mode: str + version: int + + @classmethod + def for_document( + cls, source_sha256: str, *, etl_service: str, mode: str, version: int + ) -> "ParseKey": + return cls( + source_sha256=source_sha256, + etl_service=etl_service, + mode=mode, + version=version, + ) + + @property + def object_suffix(self) -> str: + return f"{self.etl_service}.{self.mode}.v{self.version}.md" From 3c9ea0011d2dd0b9ccf75e2b46de279eef610d63 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:22:48 +0200 Subject: [PATCH 042/212] feat(etl-cache): add EvictionCandidate value object --- .../cache/schemas/eviction_candidate.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 surfsense_backend/app/etl_pipeline/cache/schemas/eviction_candidate.py diff --git a/surfsense_backend/app/etl_pipeline/cache/schemas/eviction_candidate.py b/surfsense_backend/app/etl_pipeline/cache/schemas/eviction_candidate.py new file mode 100644 index 000000000..13a903e7d --- /dev/null +++ b/surfsense_backend/app/etl_pipeline/cache/schemas/eviction_candidate.py @@ -0,0 +1,15 @@ +"""Row projection handed to the eviction policy.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime + + +@dataclass(frozen=True, slots=True) +class EvictionCandidate: + id: int + storage_key: str + size_bytes: int + last_used_at: datetime + times_reused: int From b84debd99917ecce372a207d82774002748f37a2 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:22:48 +0200 Subject: [PATCH 043/212] feat(etl-cache): expose cache schema value objects --- .../app/etl_pipeline/cache/schemas/__init__.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 surfsense_backend/app/etl_pipeline/cache/schemas/__init__.py diff --git a/surfsense_backend/app/etl_pipeline/cache/schemas/__init__.py b/surfsense_backend/app/etl_pipeline/cache/schemas/__init__.py new file mode 100644 index 000000000..c88ac0c72 --- /dev/null +++ b/surfsense_backend/app/etl_pipeline/cache/schemas/__init__.py @@ -0,0 +1,11 @@ +"""Pure value objects for the parse cache.""" + +from __future__ import annotations + +from .eviction_candidate import EvictionCandidate +from .parse_key import ParseKey + +__all__ = [ + "EvictionCandidate", + "ParseKey", +] From 205a63b9bccc3b02c11199aa6ae625f0d48026f9 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:22:48 +0200 Subject: [PATCH 044/212] feat(etl-cache): add EtlCacheSettings resolved from config --- .../app/etl_pipeline/cache/settings.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 surfsense_backend/app/etl_pipeline/cache/settings.py diff --git a/surfsense_backend/app/etl_pipeline/cache/settings.py b/surfsense_backend/app/etl_pipeline/cache/settings.py new file mode 100644 index 000000000..5911ea222 --- /dev/null +++ b/surfsense_backend/app/etl_pipeline/cache/settings.py @@ -0,0 +1,33 @@ +"""Cache configuration resolved from the central ``Config``.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class EtlCacheSettings: + enabled: bool + parser_version: int + ttl_days: int + max_total_bytes: int + eviction_batch: int + # None for any storage_* field means: reuse the main file_storage backend. + storage_backend: str | None + storage_container: str | None + storage_local_root: str | None + + +def load_etl_cache_settings() -> EtlCacheSettings: + from app.config import config + + return EtlCacheSettings( + enabled=config.ETL_CACHE_ENABLED, + parser_version=config.ETL_CACHE_PARSER_VERSION, + ttl_days=config.ETL_CACHE_TTL_DAYS, + max_total_bytes=config.ETL_CACHE_MAX_TOTAL_MB * 1024 * 1024, + eviction_batch=config.ETL_CACHE_EVICTION_BATCH, + storage_backend=config.ETL_CACHE_STORAGE_BACKEND or None, + storage_container=config.ETL_CACHE_STORAGE_CONTAINER or None, + storage_local_root=config.ETL_CACHE_STORAGE_LOCAL_PATH or None, + ) From c624235780fc965f8f648e85a3e2b68cfa456df4 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:22:48 +0200 Subject: [PATCH 045/212] feat(etl-cache): add CachedParse table model --- .../etl_pipeline/cache/persistence/models.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 surfsense_backend/app/etl_pipeline/cache/persistence/models.py diff --git a/surfsense_backend/app/etl_pipeline/cache/persistence/models.py b/surfsense_backend/app/etl_pipeline/cache/persistence/models.py new file mode 100644 index 000000000..bd20bdd12 --- /dev/null +++ b/surfsense_backend/app/etl_pipeline/cache/persistence/models.py @@ -0,0 +1,49 @@ +"""``etl_cache_parses``: one reusable parser result per (bytes + recipe).""" + +from __future__ import annotations + +from sqlalchemy import ( + BigInteger, + Column, + DateTime, + Index, + Integer, + String, + UniqueConstraint, +) + +from app.db import BaseModel, TimestampMixin + + +class CachedParse(BaseModel, TimestampMixin): + __tablename__ = "etl_cache_parses" + + # Key: raw bytes + the recipe that produced the markdown. + source_sha256 = Column(String(64), nullable=False) + etl_service = Column(String(32), nullable=False) + mode = Column(String(16), nullable=False) + parser_version = Column(Integer, nullable=False) + + # Where the markdown blob lives (kept out of the row to stay small). + storage_backend = Column(String(32), nullable=False) + storage_key = Column(String, nullable=False) + size_bytes = Column(BigInteger, nullable=False) + + # Payload needed to rebuild the EtlResult on a hit. + content_type = Column(String(32), nullable=False) + actual_pages = Column(Integer, nullable=False, default=0, server_default="0") + + # Drives eviction (popularity + recency). + times_reused = Column(BigInteger, nullable=False, default=0, server_default="0") + last_used_at = Column(DateTime(timezone=True), nullable=False) + + __table_args__ = ( + UniqueConstraint( + "source_sha256", + "etl_service", + "mode", + "parser_version", + name="uq_etl_cache_parses_key", + ), + Index("ix_etl_cache_parses_last_used_at", "last_used_at"), + ) From ea1012797911d677592172b646f73135d6bc7d98 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:22:57 +0200 Subject: [PATCH 046/212] feat(etl-cache): add CachedParseRepository data access --- .../cache/persistence/repository.py | 121 ++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 surfsense_backend/app/etl_pipeline/cache/persistence/repository.py diff --git a/surfsense_backend/app/etl_pipeline/cache/persistence/repository.py b/surfsense_backend/app/etl_pipeline/cache/persistence/repository.py new file mode 100644 index 000000000..05f40eae5 --- /dev/null +++ b/surfsense_backend/app/etl_pipeline/cache/persistence/repository.py @@ -0,0 +1,121 @@ +"""CRUD and eviction selectors for ``etl_cache_parses`` (no business rules).""" + +from __future__ import annotations + +from datetime import UTC, datetime + +from sqlalchemy import delete, func, select, update +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.ext.asyncio import AsyncSession + +from app.etl_pipeline.cache.schemas import EvictionCandidate, ParseKey + +from .models import CachedParse + +_EVICTION_COLUMNS = ( + CachedParse.id, + CachedParse.storage_key, + CachedParse.size_bytes, + CachedParse.last_used_at, + CachedParse.times_reused, +) + + +def _as_eviction_candidate(row) -> EvictionCandidate: + return EvictionCandidate( + id=row.id, + storage_key=row.storage_key, + size_bytes=row.size_bytes, + last_used_at=row.last_used_at, + times_reused=row.times_reused, + ) + + +class CachedParseRepository: + def __init__(self, session: AsyncSession) -> None: + self._session = session + + async def get(self, key: ParseKey) -> CachedParse | None: + result = await self._session.execute( + select(CachedParse).where( + CachedParse.source_sha256 == key.source_sha256, + CachedParse.etl_service == key.etl_service, + CachedParse.mode == key.mode, + CachedParse.parser_version == key.version, + ) + ) + return result.scalars().first() + + async def insert( + self, + *, + key: ParseKey, + content_type: str, + actual_pages: int, + storage_backend: str, + storage_key: str, + size_bytes: int, + ) -> None: + # Concurrent writers parse identical bytes, so a lost race is harmless. + now = datetime.now(UTC) + await self._session.execute( + pg_insert(CachedParse) + .values( + source_sha256=key.source_sha256, + etl_service=key.etl_service, + mode=key.mode, + parser_version=key.version, + content_type=content_type, + actual_pages=actual_pages, + storage_backend=storage_backend, + storage_key=storage_key, + size_bytes=size_bytes, + times_reused=0, + last_used_at=now, + created_at=now, + ) + .on_conflict_do_nothing(constraint="uq_etl_cache_parses_key") + ) + await self._session.commit() + + async def mark_used(self, row_id: int) -> None: + await self._session.execute( + update(CachedParse) + .where(CachedParse.id == row_id) + .values( + times_reused=CachedParse.times_reused + 1, + last_used_at=datetime.now(UTC), + ) + ) + await self._session.commit() + + async def total_size_bytes(self) -> int: + result = await self._session.execute( + select(func.coalesce(func.sum(CachedParse.size_bytes), 0)) + ) + return int(result.scalar() or 0) + + async def select_expired( + self, *, cutoff: datetime, limit: int + ) -> list[EvictionCandidate]: + result = await self._session.execute( + select(*_EVICTION_COLUMNS) + .where(CachedParse.last_used_at < cutoff) + .order_by(CachedParse.last_used_at.asc()) + .limit(limit) + ) + return [_as_eviction_candidate(row) for row in result] + + async def select_coldest(self, *, limit: int) -> list[EvictionCandidate]: + result = await self._session.execute( + select(*_EVICTION_COLUMNS) + .order_by(CachedParse.times_reused.asc(), CachedParse.last_used_at.asc()) + .limit(limit) + ) + return [_as_eviction_candidate(row) for row in result] + + async def delete_by_ids(self, ids: list[int]) -> None: + if not ids: + return + await self._session.execute(delete(CachedParse).where(CachedParse.id.in_(ids))) + await self._session.commit() From 8d3238bcd1fb770b7a624a60b90746b64dda2b0a Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:22:57 +0200 Subject: [PATCH 047/212] feat(etl-cache): expose cache persistence layer --- .../app/etl_pipeline/cache/persistence/__init__.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 surfsense_backend/app/etl_pipeline/cache/persistence/__init__.py diff --git a/surfsense_backend/app/etl_pipeline/cache/persistence/__init__.py b/surfsense_backend/app/etl_pipeline/cache/persistence/__init__.py new file mode 100644 index 000000000..666e4cfa8 --- /dev/null +++ b/surfsense_backend/app/etl_pipeline/cache/persistence/__init__.py @@ -0,0 +1,11 @@ +"""Database access for cached parse rows.""" + +from __future__ import annotations + +from .models import CachedParse +from .repository import CachedParseRepository + +__all__ = [ + "CachedParse", + "CachedParseRepository", +] From d9b1b491e921468c90f7d9e2e26174954331250e Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:22:57 +0200 Subject: [PATCH 048/212] feat(etl-cache): add cache blob object-key builder --- .../app/etl_pipeline/cache/storage/object_keys.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 surfsense_backend/app/etl_pipeline/cache/storage/object_keys.py diff --git a/surfsense_backend/app/etl_pipeline/cache/storage/object_keys.py b/surfsense_backend/app/etl_pipeline/cache/storage/object_keys.py new file mode 100644 index 000000000..7b89c3f92 --- /dev/null +++ b/surfsense_backend/app/etl_pipeline/cache/storage/object_keys.py @@ -0,0 +1,12 @@ +"""Object keys for cached markdown, namespaced under a dedicated prefix.""" + +from __future__ import annotations + +from app.etl_pipeline.cache.schemas import ParseKey + +CACHE_PREFIX = "etl_cache" + + +def build_parse_object_key(key: ParseKey) -> str: + # Content-addressed: identical bytes + recipe always map to the same key. + return f"{CACHE_PREFIX}/{key.source_sha256}/{key.object_suffix}" From 217d040e9e984586d5415e47458c7670f8a8dad2 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:22:57 +0200 Subject: [PATCH 049/212] feat(etl-cache): resolve cache blob storage backend --- .../app/etl_pipeline/cache/storage/backend.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 surfsense_backend/app/etl_pipeline/cache/storage/backend.py diff --git a/surfsense_backend/app/etl_pipeline/cache/storage/backend.py b/surfsense_backend/app/etl_pipeline/cache/storage/backend.py new file mode 100644 index 000000000..ac7501984 --- /dev/null +++ b/surfsense_backend/app/etl_pipeline/cache/storage/backend.py @@ -0,0 +1,46 @@ +"""Resolve the storage backend for cache blobs: shared main store or a dedicated one.""" + +from __future__ import annotations + +from functools import lru_cache + +from app.file_storage.backends.base import StorageBackend + + +@lru_cache(maxsize=1) +def resolve_cache_backend() -> StorageBackend: + from app.etl_pipeline.cache.settings import load_etl_cache_settings + + settings = load_etl_cache_settings() + + if not settings.storage_backend: + from app.file_storage.factory import get_storage_backend + + return get_storage_backend() + + backend = settings.storage_backend.strip().lower() + + if backend == "azure": + from app.config import config + + if not settings.storage_container: + raise ValueError("ETL_CACHE_STORAGE_CONTAINER is required for azure cache.") + if not config.AZURE_STORAGE_CONNECTION_STRING: + raise ValueError( + "AZURE_STORAGE_CONNECTION_STRING is required for azure cache." + ) + from app.file_storage.backends.azure import AzureBlobBackend + + return AzureBlobBackend( + connection_string=config.AZURE_STORAGE_CONNECTION_STRING, + container=settings.storage_container, + ) + + if backend == "local": + if not settings.storage_local_root: + raise ValueError("ETL_CACHE_STORAGE_LOCAL_PATH is required for local cache.") + from app.file_storage.backends.local import LocalFileBackend + + return LocalFileBackend(settings.storage_local_root) + + raise ValueError(f"Unknown ETL_CACHE_STORAGE_BACKEND: {settings.storage_backend!r}") From a6f2457c7ca73b4d51faf757f650343b52e04366 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:22:57 +0200 Subject: [PATCH 050/212] feat(etl-cache): add MarkdownCacheStore for cache blobs --- .../cache/storage/markdown_store.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 surfsense_backend/app/etl_pipeline/cache/storage/markdown_store.py diff --git a/surfsense_backend/app/etl_pipeline/cache/storage/markdown_store.py b/surfsense_backend/app/etl_pipeline/cache/storage/markdown_store.py new file mode 100644 index 000000000..189f3508b --- /dev/null +++ b/surfsense_backend/app/etl_pipeline/cache/storage/markdown_store.py @@ -0,0 +1,35 @@ +"""Read and write cached markdown blobs through the resolved backend.""" + +from __future__ import annotations + +from app.etl_pipeline.cache.schemas import ParseKey +from app.etl_pipeline.cache.storage.backend import resolve_cache_backend +from app.etl_pipeline.cache.storage.object_keys import build_parse_object_key + +_MARKDOWN_CONTENT_TYPE = "text/markdown; charset=utf-8" + + +class MarkdownCacheStore: + def __init__(self) -> None: + self._backend = resolve_cache_backend() + + @property + def backend_name(self) -> str: + return self._backend.backend_name + + async def save(self, key: ParseKey, markdown: str) -> str: + """Persist the markdown and return its storage key for the index row.""" + storage_key = build_parse_object_key(key) + await self._backend.put( + storage_key, + markdown.encode("utf-8"), + content_type=_MARKDOWN_CONTENT_TYPE, + ) + return storage_key + + async def load(self, storage_key: str) -> str: + chunks = [chunk async for chunk in self._backend.open_stream(storage_key)] + return b"".join(chunks).decode("utf-8") + + async def delete(self, storage_key: str) -> None: + await self._backend.delete(storage_key) From 87fdb37fa3f19b2d618968a8647f0c7d75c09bd6 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:23:40 +0200 Subject: [PATCH 051/212] feat(etl-cache): expose storage layer --- .../app/etl_pipeline/cache/storage/__init__.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 surfsense_backend/app/etl_pipeline/cache/storage/__init__.py diff --git a/surfsense_backend/app/etl_pipeline/cache/storage/__init__.py b/surfsense_backend/app/etl_pipeline/cache/storage/__init__.py new file mode 100644 index 000000000..bed39c510 --- /dev/null +++ b/surfsense_backend/app/etl_pipeline/cache/storage/__init__.py @@ -0,0 +1,9 @@ +"""Blob storage for cached parse markdown.""" + +from __future__ import annotations + +from .markdown_store import MarkdownCacheStore + +__all__ = [ + "MarkdownCacheStore", +] From 41dea96af4bc76861825c433967b9788a62613c8 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:23:40 +0200 Subject: [PATCH 052/212] feat(etl-cache): add EtlCacheService --- .../app/etl_pipeline/cache/service.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 surfsense_backend/app/etl_pipeline/cache/service.py diff --git a/surfsense_backend/app/etl_pipeline/cache/service.py b/surfsense_backend/app/etl_pipeline/cache/service.py new file mode 100644 index 000000000..49398faf8 --- /dev/null +++ b/surfsense_backend/app/etl_pipeline/cache/service.py @@ -0,0 +1,53 @@ +"""Recall and remember parser output, coordinating the index and blob store.""" + +from __future__ import annotations + +import logging + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.etl_pipeline.cache.persistence import CachedParseRepository +from app.etl_pipeline.cache.schemas import ParseKey +from app.etl_pipeline.cache.storage import MarkdownCacheStore +from app.etl_pipeline.etl_document import EtlResult + +logger = logging.getLogger(__name__) + + +class EtlCacheService: + def __init__(self, session: AsyncSession) -> None: + self._index = CachedParseRepository(session) + self._store = MarkdownCacheStore() + + async def recall(self, key: ParseKey) -> EtlResult | None: + """Return the cached result, or None on a miss.""" + row = await self._index.get(key) + if row is None: + return None + + try: + markdown = await self._store.load(row.storage_key) + except Exception: + # Index points at a blob that is gone; treat as a miss and re-parse. + logger.warning("Cache blob missing: %s", row.storage_key, exc_info=True) + return None + + await self._index.mark_used(row.id) + return EtlResult( + markdown_content=markdown, + etl_service=row.etl_service, + actual_pages=row.actual_pages, + content_type=row.content_type, + ) + + async def remember(self, key: ParseKey, result: EtlResult) -> None: + """Store a freshly parsed result for future reuse.""" + storage_key = await self._store.save(key, result.markdown_content) + await self._index.insert( + key=key, + content_type=result.content_type, + actual_pages=result.actual_pages, + storage_backend=self._store.backend_name, + storage_key=storage_key, + size_bytes=len(result.markdown_content.encode("utf-8")), + ) From 758da06c4f6d41d484d475012544a4c3b6558023 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:23:40 +0200 Subject: [PATCH 053/212] feat(etl-cache): add extract_with_cache --- .../etl_pipeline/cache/cached_extraction.py | 82 +++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 surfsense_backend/app/etl_pipeline/cache/cached_extraction.py diff --git a/surfsense_backend/app/etl_pipeline/cache/cached_extraction.py b/surfsense_backend/app/etl_pipeline/cache/cached_extraction.py new file mode 100644 index 000000000..5348c5f4b --- /dev/null +++ b/surfsense_backend/app/etl_pipeline/cache/cached_extraction.py @@ -0,0 +1,82 @@ +"""Entry point: serve ETL parses from cache, parsing only on a miss.""" + +from __future__ import annotations + +import asyncio +import hashlib +import logging + +from app.config import config +from app.etl_pipeline.cache.schemas import ParseKey +from app.etl_pipeline.cache.service import EtlCacheService +from app.etl_pipeline.cache.settings import load_etl_cache_settings +from app.etl_pipeline.etl_document import EtlRequest, EtlResult +from app.etl_pipeline.etl_pipeline_service import EtlPipelineService +from app.etl_pipeline.file_classifier import FileCategory, classify_file + +logger = logging.getLogger(__name__) + +_HASH_CHUNK = 1024 * 1024 + + +async def extract_with_cache( + request: EtlRequest, *, vision_llm=None +) -> EtlResult: + """Drop-in for ``EtlPipelineService.extract`` that reuses prior parser output.""" + settings = load_etl_cache_settings() + + # Vision-LLM appends model-generated content not captured by the key, so its + # output must not be shared with plain parses (and vice versa): bypass cache. + cacheable = ( + settings.enabled + and vision_llm is None + and bool(config.ETL_SERVICE) + and classify_file(request.filename) == FileCategory.DOCUMENT + ) + if not cacheable: + return await EtlPipelineService(vision_llm=vision_llm).extract(request) + + key = ParseKey.for_document( + await asyncio.to_thread(_hash_file, request.file_path), + etl_service=config.ETL_SERVICE, + mode=request.processing_mode.value, + version=settings.parser_version, + ) + + cached_result = await _recall(key) + if cached_result is not None: + return cached_result + + result = await EtlPipelineService(vision_llm=vision_llm).extract(request) + await _remember(key, result) + return result + + +async def _recall(key: ParseKey) -> EtlResult | None: + # Caching is best-effort: any failure falls through to a normal parse. + try: + from app.tasks.celery_tasks import get_celery_session_maker + + async with get_celery_session_maker()() as session: + return await EtlCacheService(session).recall(key) + except Exception: + logger.warning("ETL cache recall failed; parsing fresh", exc_info=True) + return None + + +async def _remember(key: ParseKey, result: EtlResult) -> None: + try: + from app.tasks.celery_tasks import get_celery_session_maker + + async with get_celery_session_maker()() as session: + await EtlCacheService(session).remember(key, result) + except Exception: + logger.warning("ETL cache write failed; result not cached", exc_info=True) + + +def _hash_file(path: str) -> str: + digest = hashlib.sha256() + with open(path, "rb") as handle: + for chunk in iter(lambda: handle.read(_HASH_CHUNK), b""): + digest.update(chunk) + return digest.hexdigest() From 7ad39fd995d61d03448b3feb43c4df8da8af2460 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:23:40 +0200 Subject: [PATCH 054/212] feat(etl-cache): add eviction policy --- .../app/etl_pipeline/cache/eviction/policy.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 surfsense_backend/app/etl_pipeline/cache/eviction/policy.py diff --git a/surfsense_backend/app/etl_pipeline/cache/eviction/policy.py b/surfsense_backend/app/etl_pipeline/cache/eviction/policy.py new file mode 100644 index 000000000..5a80752d6 --- /dev/null +++ b/surfsense_backend/app/etl_pipeline/cache/eviction/policy.py @@ -0,0 +1,28 @@ +"""Pure selection rules for which cached entries to drop.""" + +from __future__ import annotations + +from collections.abc import Iterable + +from app.etl_pipeline.cache.schemas import EvictionCandidate + + +def select_over_budget( + coldest_first: Iterable[EvictionCandidate], + *, + current_total_bytes: int, + max_total_bytes: int, +) -> list[EvictionCandidate]: + """Pick coldest entries until the footprint drops under the budget.""" + bytes_to_free = current_total_bytes - max_total_bytes + if bytes_to_free <= 0: + return [] + + chosen: list[EvictionCandidate] = [] + bytes_freed = 0 + for candidate in coldest_first: + if bytes_freed >= bytes_to_free: + break + chosen.append(candidate) + bytes_freed += candidate.size_bytes + return chosen From 324ba141a60648dc012f9edfd53be0eafaa9e2b4 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:23:40 +0200 Subject: [PATCH 055/212] feat(etl-cache): add eviction task and public API --- .../app/etl_pipeline/cache/__init__.py | 11 ++++ .../etl_pipeline/cache/eviction/__init__.py | 9 +++ .../app/etl_pipeline/cache/eviction/task.py | 62 +++++++++++++++++++ 3 files changed, 82 insertions(+) create mode 100644 surfsense_backend/app/etl_pipeline/cache/__init__.py create mode 100644 surfsense_backend/app/etl_pipeline/cache/eviction/__init__.py create mode 100644 surfsense_backend/app/etl_pipeline/cache/eviction/task.py diff --git a/surfsense_backend/app/etl_pipeline/cache/__init__.py b/surfsense_backend/app/etl_pipeline/cache/__init__.py new file mode 100644 index 000000000..3f4585778 --- /dev/null +++ b/surfsense_backend/app/etl_pipeline/cache/__init__.py @@ -0,0 +1,11 @@ +"""Content-addressed reuse of expensive ETL parser output across workspaces.""" + +from __future__ import annotations + +from app.etl_pipeline.cache.cached_extraction import extract_with_cache +from app.etl_pipeline.cache.service import EtlCacheService + +__all__ = [ + "EtlCacheService", + "extract_with_cache", +] diff --git a/surfsense_backend/app/etl_pipeline/cache/eviction/__init__.py b/surfsense_backend/app/etl_pipeline/cache/eviction/__init__.py new file mode 100644 index 000000000..f47b9c4e0 --- /dev/null +++ b/surfsense_backend/app/etl_pipeline/cache/eviction/__init__.py @@ -0,0 +1,9 @@ +"""Background pruning of the parse cache by age and size budget.""" + +from __future__ import annotations + +from .task import evict_etl_cache_task + +__all__ = [ + "evict_etl_cache_task", +] diff --git a/surfsense_backend/app/etl_pipeline/cache/eviction/task.py b/surfsense_backend/app/etl_pipeline/cache/eviction/task.py new file mode 100644 index 000000000..98841b139 --- /dev/null +++ b/surfsense_backend/app/etl_pipeline/cache/eviction/task.py @@ -0,0 +1,62 @@ +"""Celery task that prunes the parse cache by TTL, then by size budget.""" + +from __future__ import annotations + +import contextlib +import logging +from datetime import UTC, datetime, timedelta + +from app.celery_app import celery_app +from app.etl_pipeline.cache.eviction.policy import select_over_budget +from app.etl_pipeline.cache.persistence import CachedParseRepository +from app.etl_pipeline.cache.schemas import EvictionCandidate +from app.etl_pipeline.cache.settings import load_etl_cache_settings +from app.etl_pipeline.cache.storage import MarkdownCacheStore +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task + +logger = logging.getLogger(__name__) + + +@celery_app.task(name="evict_etl_cache") +def evict_etl_cache_task(): + return run_async_celery_task(_evict) + + +async def _evict() -> None: + """Expire stale entries, then shed the coldest overflow only if still over budget.""" + settings = load_etl_cache_settings() + if not settings.enabled: + return + + store = MarkdownCacheStore() + async with get_celery_session_maker()() as session: + index = CachedParseRepository(session) + + cutoff = datetime.now(UTC) - timedelta(days=settings.ttl_days) + expired = await index.select_expired(cutoff=cutoff, limit=settings.eviction_batch) + await _drop(index, store, expired) + + total = await index.total_size_bytes() + if total > settings.max_total_bytes: + coldest = await index.select_coldest(limit=settings.eviction_batch) + over_budget = select_over_budget( + coldest, + current_total_bytes=total, + max_total_bytes=settings.max_total_bytes, + ) + await _drop(index, store, over_budget) + + +async def _drop( + index: CachedParseRepository, + store: MarkdownCacheStore, + candidates: list[EvictionCandidate], +) -> None: + if not candidates: + return + for candidate in candidates: + # Drop the index row even if the blob delete fails (orphan blob is harmless). + with contextlib.suppress(Exception): + await store.delete(candidate.storage_key) + await index.delete_by_ids([candidate.id for candidate in candidates]) + logger.info("Evicted %d cached parses", len(candidates)) From 5c4eec26cc21bcaeb3002746fba49be89e04b35b Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:23:50 +0200 Subject: [PATCH 056/212] feat(config): add ETL_CACHE_* settings --- surfsense_backend/app/config/__init__.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index bbaf3ac55..525fe160d 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -952,6 +952,18 @@ class Config: AZURE_DI_ENDPOINT = os.getenv("AZURE_DI_ENDPOINT") AZURE_DI_KEY = os.getenv("AZURE_DI_KEY") + # ETL parse cache: reuse parser output for identical bytes across workspaces. + ETL_CACHE_ENABLED = os.getenv("ETL_CACHE_ENABLED", "false").strip().lower() == "true" + # Bump to invalidate every cached entry after a parser/behaviour change. + ETL_CACHE_PARSER_VERSION = int(os.getenv("ETL_CACHE_PARSER_VERSION", "1")) + ETL_CACHE_TTL_DAYS = int(os.getenv("ETL_CACHE_TTL_DAYS", "90")) + ETL_CACHE_MAX_TOTAL_MB = int(os.getenv("ETL_CACHE_MAX_TOTAL_MB", "5120")) + ETL_CACHE_EVICTION_BATCH = int(os.getenv("ETL_CACHE_EVICTION_BATCH", "500")) + # Optional dedicated blob storage; unset reuses the main file_storage backend. + ETL_CACHE_STORAGE_BACKEND = os.getenv("ETL_CACHE_STORAGE_BACKEND") + ETL_CACHE_STORAGE_CONTAINER = os.getenv("ETL_CACHE_STORAGE_CONTAINER") + ETL_CACHE_STORAGE_LOCAL_PATH = os.getenv("ETL_CACHE_STORAGE_LOCAL_PATH") + # Proxy provider selection. Maps to a ProxyProvider implementation registered # in app/utils/proxy/registry.py. Add new vendors there and switch via this var. PROXY_PROVIDER = os.getenv("PROXY_PROVIDER", "anonymous_proxies") From 9f29a885b1aa97cc3937c5a94c5dd3413f9988e0 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:23:50 +0200 Subject: [PATCH 057/212] feat(db): register CachedParse model --- surfsense_backend/app/db.py | 1 + 1 file changed, 1 insertion(+) diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 2d672131b..97843d395 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -2864,6 +2864,7 @@ from app.automations.persistence import ( # noqa: E402, F401 AutomationRun, AutomationTrigger, ) +from app.etl_pipeline.cache.persistence.models import CachedParse # noqa: E402, F401 from app.file_storage.persistence import DocumentFile # noqa: E402, F401 from app.notifications.persistence import Notification # noqa: E402, F401 from app.podcasts.persistence import ( # noqa: E402, F401 From 1c05980ffbd2f89bb62c43d59886230ce931b659 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:23:50 +0200 Subject: [PATCH 058/212] feat(celery): schedule etl cache eviction --- surfsense_backend/app/celery_app.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py index 5eebffd65..413522189 100644 --- a/surfsense_backend/app/celery_app.py +++ b/surfsense_backend/app/celery_app.py @@ -192,6 +192,7 @@ celery_app = Celery( "app.tasks.celery_tasks.stripe_reconciliation_task", "app.tasks.celery_tasks.auto_reload_task", "app.tasks.celery_tasks.gateway_tasks", + "app.etl_pipeline.cache.eviction.task", "app.automations.tasks.execute_run", "app.automations.triggers.builtin.schedule.selector", "app.automations.triggers.builtin.event.selector", @@ -306,6 +307,12 @@ celery_app.conf.beat_schedule = { "schedule": crontab(hour="3", minute="17"), "options": {"expires": 600}, }, + # Prune the ETL parse cache (TTL + size budget) once daily, off-peak. + "evict-etl-cache": { + "task": "evict_etl_cache", + "schedule": crontab(hour="4", minute="0"), + "options": {"expires": 600}, + }, # Fire due automation schedule triggers (Beat entry owned by the schedule # trigger; see app.automations.triggers.builtin.schedule.source). **SCHEDULE_BEAT_SCHEDULE, From 0dc2ccc003aed18aa17ed9fe5486bf330a503d42 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:23:50 +0200 Subject: [PATCH 059/212] feat(tasks): route extraction through etl cache --- .../app/tasks/document_processors/file_processors.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/surfsense_backend/app/tasks/document_processors/file_processors.py b/surfsense_backend/app/tasks/document_processors/file_processors.py index a646b7aa6..0c3d30766 100644 --- a/surfsense_backend/app/tasks/document_processors/file_processors.py +++ b/surfsense_backend/app/tasks/document_processors/file_processors.py @@ -381,7 +381,6 @@ async def _extract_file_content( Tuple of (markdown_content, etl_service_name, billable_pages). """ from app.etl_pipeline.etl_document import EtlRequest, ProcessingMode - from app.etl_pipeline.etl_pipeline_service import EtlPipelineService from app.etl_pipeline.file_classifier import ( FileCategory, classify_file as etl_classify, @@ -432,13 +431,16 @@ async def _extract_file_content( vision_llm = await get_vision_llm(session, search_space_id) - result = await EtlPipelineService(vision_llm=vision_llm).extract( + from app.etl_pipeline.cache import extract_with_cache + + result = await extract_with_cache( EtlRequest( file_path=file_path, filename=filename, estimated_pages=estimated_pages, processing_mode=mode, - ) + ), + vision_llm=vision_llm, ) with contextlib.suppress(Exception): From d898716cf4c9baa62fd6767844a301187addbbe8 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:23:50 +0200 Subject: [PATCH 060/212] feat(migration): add etl_cache_parses table --- .../versions/160_add_etl_cache_parses.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 surfsense_backend/alembic/versions/160_add_etl_cache_parses.py diff --git a/surfsense_backend/alembic/versions/160_add_etl_cache_parses.py b/surfsense_backend/alembic/versions/160_add_etl_cache_parses.py new file mode 100644 index 000000000..f021a962a --- /dev/null +++ b/surfsense_backend/alembic/versions/160_add_etl_cache_parses.py @@ -0,0 +1,53 @@ +"""add etl_cache_parses table for content-addressed parse reuse + +Revision ID: 160 +Revises: 159 +""" + +from collections.abc import Sequence + +from alembic import op + +revision: str = "160" +down_revision: str | None = "159" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.execute( + """ + CREATE TABLE IF NOT EXISTS etl_cache_parses ( + id SERIAL PRIMARY KEY, + source_sha256 VARCHAR(64) NOT NULL, + etl_service VARCHAR(32) NOT NULL, + mode VARCHAR(16) NOT NULL, + parser_version INTEGER NOT NULL, + storage_backend VARCHAR(32) NOT NULL, + storage_key TEXT NOT NULL, + size_bytes BIGINT NOT NULL, + content_type VARCHAR(32) NOT NULL, + actual_pages INTEGER NOT NULL DEFAULT 0, + times_reused BIGINT NOT NULL DEFAULT 0, + last_used_at TIMESTAMP WITH TIME ZONE NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + CONSTRAINT uq_etl_cache_parses_key + UNIQUE (source_sha256, etl_service, mode, parser_version) + ); + """ + ) + + op.execute( + "CREATE INDEX IF NOT EXISTS ix_etl_cache_parses_last_used_at " + "ON etl_cache_parses(last_used_at);" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_etl_cache_parses_created_at " + "ON etl_cache_parses(created_at);" + ) + + +def downgrade() -> None: + op.execute("DROP INDEX IF EXISTS ix_etl_cache_parses_created_at;") + op.execute("DROP INDEX IF EXISTS ix_etl_cache_parses_last_used_at;") + op.execute("DROP TABLE IF EXISTS etl_cache_parses;") From 5af594c405b44a75d0142e7070dbdeb013d67e7d Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:23:50 +0200 Subject: [PATCH 061/212] docs(env): document ETL_CACHE_* settings --- surfsense_backend/.env.example | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index b4f67328c..1924756ce 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -311,6 +311,23 @@ FILE_STORAGE_BACKEND=local # AZURE_STORAGE_CONNECTION_STRING=DefaultEndpointsProtocol=https;AccountName=...;AccountKey=...;EndpointSuffix=core.windows.net # AZURE_STORAGE_CONTAINER=surfsense-documents +# ETL Parse Cache +# Reuse parser output for identical file bytes across workspaces (skips paid +# re-parsing on LlamaCloud / Azure DI / Unstructured). Off by default. +ETL_CACHE_ENABLED=false +# Bump to invalidate all cached entries after a parser/behaviour change. +# ETL_CACHE_PARSER_VERSION=1 +# Prune entries unused for this many days. +# ETL_CACHE_TTL_DAYS=90 +# Soft cap on total cached markdown; coldest entries are evicted past it. +# ETL_CACHE_MAX_TOTAL_MB=5120 +# Rows deleted per eviction pass. +# ETL_CACHE_EVICTION_BATCH=500 +# Optional dedicated blob storage; unset reuses the main file storage backend. +# ETL_CACHE_STORAGE_BACKEND=azure +# ETL_CACHE_STORAGE_CONTAINER=surfsense-etl-cache +# ETL_CACHE_STORAGE_LOCAL_PATH=/var/lib/surfsense/etl-cache + # Daytona Sandbox (isolated code execution) # DAYTONA_SANDBOX_ENABLED=FALSE # DAYTONA_API_KEY=your-daytona-api-key From ce1e90386fac8437a9a855e458d54b2e212e1424 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:50:51 +0200 Subject: [PATCH 062/212] refactor(etl-cache): extract pure cacheability gate --- .../etl_pipeline/cache/cached_extraction.py | 14 ++++------ .../app/etl_pipeline/cache/eligibility.py | 28 +++++++++++++++++++ .../etl_pipeline/cache/schemas/parse_key.py | 2 +- 3 files changed, 35 insertions(+), 9 deletions(-) create mode 100644 surfsense_backend/app/etl_pipeline/cache/eligibility.py diff --git a/surfsense_backend/app/etl_pipeline/cache/cached_extraction.py b/surfsense_backend/app/etl_pipeline/cache/cached_extraction.py index 5348c5f4b..dba4b44da 100644 --- a/surfsense_backend/app/etl_pipeline/cache/cached_extraction.py +++ b/surfsense_backend/app/etl_pipeline/cache/cached_extraction.py @@ -7,12 +7,12 @@ import hashlib import logging from app.config import config +from app.etl_pipeline.cache.eligibility import is_parse_cacheable from app.etl_pipeline.cache.schemas import ParseKey from app.etl_pipeline.cache.service import EtlCacheService from app.etl_pipeline.cache.settings import load_etl_cache_settings from app.etl_pipeline.etl_document import EtlRequest, EtlResult from app.etl_pipeline.etl_pipeline_service import EtlPipelineService -from app.etl_pipeline.file_classifier import FileCategory, classify_file logger = logging.getLogger(__name__) @@ -25,13 +25,11 @@ async def extract_with_cache( """Drop-in for ``EtlPipelineService.extract`` that reuses prior parser output.""" settings = load_etl_cache_settings() - # Vision-LLM appends model-generated content not captured by the key, so its - # output must not be shared with plain parses (and vice versa): bypass cache. - cacheable = ( - settings.enabled - and vision_llm is None - and bool(config.ETL_SERVICE) - and classify_file(request.filename) == FileCategory.DOCUMENT + cacheable = is_parse_cacheable( + filename=request.filename, + etl_service=config.ETL_SERVICE, + cache_enabled=settings.enabled, + has_vision_llm=vision_llm is not None, ) if not cacheable: return await EtlPipelineService(vision_llm=vision_llm).extract(request) diff --git a/surfsense_backend/app/etl_pipeline/cache/eligibility.py b/surfsense_backend/app/etl_pipeline/cache/eligibility.py new file mode 100644 index 000000000..18f096218 --- /dev/null +++ b/surfsense_backend/app/etl_pipeline/cache/eligibility.py @@ -0,0 +1,28 @@ +"""Gating rule: may this upload be served from / written to the parse cache?""" + +from __future__ import annotations + +from app.etl_pipeline.file_classifier import FileCategory, classify_file + + +def is_parse_cacheable( + *, + filename: str, + etl_service: str | None, + cache_enabled: bool, + has_vision_llm: bool, +) -> bool: + """Only deterministic document parses are shareable across workspaces. + + Vision-LLM runs append model-generated content not captured by the cache key, + and a missing ETL service means there is no document parser to key against -- + both bypass the cache. Non-document categories (plaintext, audio, images, + direct-convert) are cheap or parser-agnostic and are handled outside it. + """ + if not cache_enabled: + return False + if has_vision_llm: + return False + if not etl_service: + return False + return classify_file(filename) == FileCategory.DOCUMENT diff --git a/surfsense_backend/app/etl_pipeline/cache/schemas/parse_key.py b/surfsense_backend/app/etl_pipeline/cache/schemas/parse_key.py index 65e7b08a5..88133a418 100644 --- a/surfsense_backend/app/etl_pipeline/cache/schemas/parse_key.py +++ b/surfsense_backend/app/etl_pipeline/cache/schemas/parse_key.py @@ -15,7 +15,7 @@ class ParseKey: @classmethod def for_document( cls, source_sha256: str, *, etl_service: str, mode: str, version: int - ) -> "ParseKey": + ) -> ParseKey: return cls( source_sha256=source_sha256, etl_service=etl_service, From dddacbe762b4c52652db79e3b1d4d46bf5c88c71 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:50:52 +0200 Subject: [PATCH 063/212] test(etl-cache): cover content-addressing dedup and key shape --- .../tests/unit/etl_pipeline/cache/conftest.py | 28 ++++++++ .../unit/etl_pipeline/cache/test_parse_key.py | 70 +++++++++++++++++++ 2 files changed, 98 insertions(+) create mode 100644 surfsense_backend/tests/unit/etl_pipeline/cache/conftest.py create mode 100644 surfsense_backend/tests/unit/etl_pipeline/cache/test_parse_key.py diff --git a/surfsense_backend/tests/unit/etl_pipeline/cache/conftest.py b/surfsense_backend/tests/unit/etl_pipeline/cache/conftest.py new file mode 100644 index 000000000..c6efddc09 --- /dev/null +++ b/surfsense_backend/tests/unit/etl_pipeline/cache/conftest.py @@ -0,0 +1,28 @@ +"""Stub the cache package __init__s so unit tests import only pure leaf modules. + +The real ``cache``/``storage``/``eviction``/``persistence`` __init__s eagerly +import the facade, file storage, Celery, and ``app.db`` -- none of which a pure +unit test should need. Turning those packages into bare namespace packages lets +``from app.etl_pipeline.cache.. import ...`` resolve the leaf module +without running the heavy __init__. ``schemas`` is left real (it is pure). +""" + +import sys +import types +from pathlib import Path + +_CACHE_DIR = Path(__file__).resolve().parents[4] / "app" / "etl_pipeline" / "cache" + + +def _stub_namespace_package(dotted: str, fs_dir: Path) -> None: + if dotted in sys.modules: + return + module = types.ModuleType(dotted) + module.__path__ = [str(fs_dir)] + module.__package__ = dotted + sys.modules[dotted] = module + + +_stub_namespace_package("app.etl_pipeline.cache", _CACHE_DIR) +_stub_namespace_package("app.etl_pipeline.cache.storage", _CACHE_DIR / "storage") +_stub_namespace_package("app.etl_pipeline.cache.eviction", _CACHE_DIR / "eviction") diff --git a/surfsense_backend/tests/unit/etl_pipeline/cache/test_parse_key.py b/surfsense_backend/tests/unit/etl_pipeline/cache/test_parse_key.py new file mode 100644 index 000000000..d69e74ee0 --- /dev/null +++ b/surfsense_backend/tests/unit/etl_pipeline/cache/test_parse_key.py @@ -0,0 +1,70 @@ +"""Content-addressing: equal (bytes + recipe) must map to one storage location. + +This is the dedup guarantee the whole cache rests on -- two users uploading the +same file under the same parser settings have to land on the same object key, and +any change to bytes or recipe has to land somewhere else. +""" + +from __future__ import annotations + +import pytest + +from app.etl_pipeline.cache.schemas import ParseKey +from app.etl_pipeline.cache.storage.object_keys import ( + CACHE_PREFIX, + build_parse_object_key, +) + +pytestmark = pytest.mark.unit + + +def _key(**overrides) -> ParseKey: + base = { + "source_sha256": "a" * 64, + "etl_service": "LLAMACLOUD", + "mode": "basic", + "version": 1, + } + base.update(overrides) + return ParseKey.for_document( + base["source_sha256"], + etl_service=base["etl_service"], + mode=base["mode"], + version=base["version"], + ) + + +def test_same_bytes_and_recipe_produce_the_same_object_key(): + assert build_parse_object_key(_key()) == build_parse_object_key(_key()) + + +def test_different_bytes_produce_different_object_keys(): + assert build_parse_object_key( + _key(source_sha256="a" * 64) + ) != build_parse_object_key(_key(source_sha256="b" * 64)) + + +@pytest.mark.parametrize( + "field, value", + [ + ("etl_service", "DOCLING"), + ("mode", "premium"), + ("version", 2), + ], +) +def test_any_recipe_change_produces_a_different_object_key(field, value): + # Same bytes but a different parser/mode/version must not collide: the recipe + # is part of the identity, so changing it has to re-parse, not reuse. + assert build_parse_object_key(_key()) != build_parse_object_key( + _key(**{field: value}) + ) + + +def test_object_key_is_prefixed_and_sharded_by_source_hash(): + # Shape matters operationally: a dedicated top-level prefix keeps cache blobs + # out of the normal store, and the sha directory groups every recipe variant + # of one file together. + key = _key() + assert build_parse_object_key(key) == ( + f"{CACHE_PREFIX}/{key.source_sha256}/LLAMACLOUD.basic.v1.md" + ) From a3e7047c3508cd65bd48eee68b6d9f8ca53a7a9c Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:50:52 +0200 Subject: [PATCH 064/212] test(etl-cache): cover cacheability gate rules --- .../etl_pipeline/cache/test_eligibility.py | 88 +++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 surfsense_backend/tests/unit/etl_pipeline/cache/test_eligibility.py diff --git a/surfsense_backend/tests/unit/etl_pipeline/cache/test_eligibility.py b/surfsense_backend/tests/unit/etl_pipeline/cache/test_eligibility.py new file mode 100644 index 000000000..99d8e67b6 --- /dev/null +++ b/surfsense_backend/tests/unit/etl_pipeline/cache/test_eligibility.py @@ -0,0 +1,88 @@ +"""What is allowed into the cache -- the gating rules, as pure logic. + +These rules decide whether a given upload may be served from / written to the +parse cache. They live in a pure predicate so every branch (disabled, vision, +no service, file category) is covered here without touching DB, storage, or the +parser. +""" + +from __future__ import annotations + +import pytest + +from app.etl_pipeline.cache.eligibility import is_parse_cacheable + +pytestmark = pytest.mark.unit + + +def test_document_with_service_and_cache_on_is_cacheable(): + assert is_parse_cacheable( + filename="report.pdf", + etl_service="LLAMACLOUD", + cache_enabled=True, + has_vision_llm=False, + ) + + +def test_disabled_cache_is_never_cacheable(): + assert not is_parse_cacheable( + filename="report.pdf", + etl_service="LLAMACLOUD", + cache_enabled=False, + has_vision_llm=False, + ) + + +def test_vision_llm_run_is_not_cacheable(): + # Vision appends model output not captured by the key; sharing it would leak + # one run's generated text into a plain parse of the same bytes. + assert not is_parse_cacheable( + filename="report.pdf", + etl_service="LLAMACLOUD", + cache_enabled=True, + has_vision_llm=True, + ) + + +@pytest.mark.parametrize("etl_service", [None, ""]) +def test_missing_etl_service_is_not_cacheable(etl_service): + assert not is_parse_cacheable( + filename="report.pdf", + etl_service=etl_service, + cache_enabled=True, + has_vision_llm=False, + ) + + +@pytest.mark.parametrize( + "filename", + ["paper.pdf", "memo.docx", "slides.pptx", "sheet.xlsx", "book.epub"], +) +def test_document_extensions_are_cacheable(filename): + assert is_parse_cacheable( + filename=filename, + etl_service="LLAMACLOUD", + cache_enabled=True, + has_vision_llm=False, + ) + + +@pytest.mark.parametrize( + "filename", + [ + "notes.txt", # plaintext + "readme.md", # plaintext + "main.py", # plaintext + "podcast.mp3", # audio + "photo.png", # image (vision path / fallback, not a shared doc parse) + "data.csv", # direct-convert + "archive.xyz", # unsupported + ], +) +def test_non_document_categories_are_not_cacheable(filename): + assert not is_parse_cacheable( + filename=filename, + etl_service="LLAMACLOUD", + cache_enabled=True, + has_vision_llm=False, + ) From 3dec3231d0f845225a2502d8a261afcdc7d73b80 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:50:52 +0200 Subject: [PATCH 065/212] test(etl-cache): cover over-budget eviction selection --- .../cache/test_eviction_policy.py | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 surfsense_backend/tests/unit/etl_pipeline/cache/test_eviction_policy.py diff --git a/surfsense_backend/tests/unit/etl_pipeline/cache/test_eviction_policy.py b/surfsense_backend/tests/unit/etl_pipeline/cache/test_eviction_policy.py new file mode 100644 index 000000000..5113d7c42 --- /dev/null +++ b/surfsense_backend/tests/unit/etl_pipeline/cache/test_eviction_policy.py @@ -0,0 +1,76 @@ +"""Size-based eviction: drop just enough of the coldest entries to fit budget. + +The caller supplies candidates already ordered coldest-first; this pure rule only +decides how far down that list to cut. It must never over-evict (stop as soon as +the footprint fits) and never promise more than the candidates can free. +""" + +from __future__ import annotations + +from datetime import UTC, datetime + +import pytest + +from app.etl_pipeline.cache.eviction.policy import select_over_budget +from app.etl_pipeline.cache.schemas import EvictionCandidate + +pytestmark = pytest.mark.unit + + +def _candidate(id_: int, size_bytes: int) -> EvictionCandidate: + return EvictionCandidate( + id=id_, + storage_key=f"etl_cache/{id_}.md", + size_bytes=size_bytes, + last_used_at=datetime(2026, 1, 1, tzinfo=UTC), + times_reused=0, + ) + + +def test_over_budget_drops_coldest_until_it_fits(): + # 300 used, budget 100 -> must free >=200. Coldest-first [120, 90, 70]; + # 120+90=210 >=200, so the third (70) is spared. + coldest_first = [_candidate(1, 120), _candidate(2, 90), _candidate(3, 70)] + + chosen = select_over_budget( + coldest_first, current_total_bytes=300, max_total_bytes=100 + ) + + assert [c.id for c in chosen] == [1, 2] + + +@pytest.mark.parametrize("current_total_bytes", [100, 80]) +def test_within_budget_evicts_nothing(current_total_bytes): + # At or under budget there is nothing to free, so no blob is touched. + coldest_first = [_candidate(1, 50), _candidate(2, 50)] + + chosen = select_over_budget( + coldest_first, + current_total_bytes=current_total_bytes, + max_total_bytes=100, + ) + + assert chosen == [] + + +def test_stops_as_soon_as_one_entry_covers_the_overage(): + # Only 10 over budget; the first (cold) entry already frees enough. + coldest_first = [_candidate(1, 40), _candidate(2, 40)] + + chosen = select_over_budget( + coldest_first, current_total_bytes=110, max_total_bytes=100 + ) + + assert [c.id for c in chosen] == [1] + + +def test_returns_all_candidates_when_they_cannot_free_enough(): + # Deficit is 500 but candidates only total 150: return everything available + # rather than looping forever or raising. + coldest_first = [_candidate(1, 100), _candidate(2, 50)] + + chosen = select_over_budget( + coldest_first, current_total_bytes=600, max_total_bytes=100 + ) + + assert [c.id for c in chosen] == [1, 2] From c49a0f1233bbce2fe556eeaac95654ad60ec09df Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:50:57 +0200 Subject: [PATCH 066/212] test(etl-cache): cover store, service, and repository on real infra --- .../etl_pipeline/cache/conftest.py | 32 +++++++ .../cache/test_cached_parse_repository.py | 96 +++++++++++++++++++ .../cache/test_etl_cache_service.py | 67 +++++++++++++ .../etl_pipeline/cache/test_markdown_store.py | 42 ++++++++ 4 files changed, 237 insertions(+) create mode 100644 surfsense_backend/tests/integration/etl_pipeline/cache/conftest.py create mode 100644 surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_parse_repository.py create mode 100644 surfsense_backend/tests/integration/etl_pipeline/cache/test_etl_cache_service.py create mode 100644 surfsense_backend/tests/integration/etl_pipeline/cache/test_markdown_store.py diff --git a/surfsense_backend/tests/integration/etl_pipeline/cache/conftest.py b/surfsense_backend/tests/integration/etl_pipeline/cache/conftest.py new file mode 100644 index 000000000..4369cc64d --- /dev/null +++ b/surfsense_backend/tests/integration/etl_pipeline/cache/conftest.py @@ -0,0 +1,32 @@ +"""Real-infra fixtures for the parse-cache integration tests. + +``cache_local_storage`` points the cache's blob store at a throwaway directory so +tests exercise the real ``LocalFileBackend`` (no cloud, no mocks). ``clean_cache_table`` +removes rows written through the facade's own committing session, which the +savepoint-rolled-back ``db_session`` cannot undo. +""" + +from __future__ import annotations + +import pytest +import pytest_asyncio +from sqlalchemy import text + + +@pytest.fixture +def cache_local_storage(tmp_path, monkeypatch): + from app.config import config + from app.etl_pipeline.cache.storage.backend import resolve_cache_backend + + monkeypatch.setattr(config, "ETL_CACHE_STORAGE_BACKEND", "local") + monkeypatch.setattr(config, "ETL_CACHE_STORAGE_LOCAL_PATH", str(tmp_path)) + resolve_cache_backend.cache_clear() + yield tmp_path + resolve_cache_backend.cache_clear() + + +@pytest_asyncio.fixture +async def clean_cache_table(async_engine): + yield + async with async_engine.begin() as conn: + await conn.execute(text("DELETE FROM etl_cache_parses")) diff --git a/surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_parse_repository.py b/surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_parse_repository.py new file mode 100644 index 000000000..72e977f11 --- /dev/null +++ b/surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_parse_repository.py @@ -0,0 +1,96 @@ +"""CachedParseRepository against real Postgres: the SQL behind eviction & dedup. + +These verify the parts that only a real database can: coldest-first ordering by +reuse then recency, TTL cutoff selection, the size accumulator, and the +insert-once guarantee under a duplicate key. +""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta + +import pytest + +from app.etl_pipeline.cache.persistence import CachedParseRepository +from app.etl_pipeline.cache.schemas import ParseKey + +pytestmark = pytest.mark.integration + + +def _key(sha: str) -> ParseKey: + return ParseKey.for_document( + sha, etl_service="LLAMACLOUD", mode="basic", version=1 + ) + + +async def _insert(repo, *, sha, size=100, storage_key=None): + key = _key(sha) + await repo.insert( + key=key, + content_type="application/pdf", + actual_pages=1, + storage_backend="local", + storage_key=storage_key or f"etl_cache/{sha}.md", + size_bytes=size, + ) + return key + + +async def test_total_size_bytes_sums_all_rows(db_session): + repo = CachedParseRepository(db_session) + await _insert(repo, sha="a" * 64, size=100) + await _insert(repo, sha="b" * 64, size=250) + + assert await repo.total_size_bytes() == 350 + + +async def test_select_coldest_orders_by_reuse_then_recency(db_session): + repo = CachedParseRepository(db_session) + ka = await _insert(repo, sha="a" * 64) + kb = await _insert(repo, sha="b" * 64) + kc = await _insert(repo, sha="c" * 64) + + # Warm B once and C twice; A stays untouched and should be coldest. + await repo.mark_used((await repo.get(kb)).id) + await repo.mark_used((await repo.get(kc)).id) + await repo.mark_used((await repo.get(kc)).id) + + coldest = await repo.select_coldest(limit=10) + + ids_by_reuse = [c.id for c in coldest] + assert ids_by_reuse[:3] == [ + (await repo.get(ka)).id, + (await repo.get(kb)).id, + (await repo.get(kc)).id, + ] + + +async def test_select_expired_returns_only_rows_older_than_cutoff(db_session): + repo = CachedParseRepository(db_session) + await _insert(repo, sha="a" * 64) + + future = datetime.now(UTC) + timedelta(days=1) + past = datetime.now(UTC) - timedelta(days=1) + + # Row was just used, so it's older than a future cutoff but not a past one. + assert len(await repo.select_expired(cutoff=future, limit=10)) == 1 + assert await repo.select_expired(cutoff=past, limit=10) == [] + + +async def test_duplicate_key_insert_keeps_the_first_row(db_session): + repo = CachedParseRepository(db_session) + key = await _insert(repo, sha="a" * 64, size=100, storage_key="etl_cache/first.md") + + # Same content-addressed key (a concurrent re-parse): must be a no-op. + await repo.insert( + key=key, + content_type="application/pdf", + actual_pages=1, + storage_backend="local", + storage_key="etl_cache/second.md", + size_bytes=999, + ) + + row = await repo.get(key) + assert row.storage_key == "etl_cache/first.md" + assert await repo.total_size_bytes() == 100 diff --git a/surfsense_backend/tests/integration/etl_pipeline/cache/test_etl_cache_service.py b/surfsense_backend/tests/integration/etl_pipeline/cache/test_etl_cache_service.py new file mode 100644 index 000000000..df74c97d4 --- /dev/null +++ b/surfsense_backend/tests/integration/etl_pipeline/cache/test_etl_cache_service.py @@ -0,0 +1,67 @@ +"""EtlCacheService end-to-end against real Postgres + real local storage. + +Exercises the public cache surface -- ``recall`` / ``remember`` -- with no mocks: +a miss returns nothing, and a remembered parse comes back as an equivalent +``EtlResult`` rebuilt from the row and the blob. +""" + +from __future__ import annotations + +import pytest + +from app.etl_pipeline.cache.schemas import ParseKey +from app.etl_pipeline.cache.service import EtlCacheService +from app.etl_pipeline.etl_document import EtlResult + +pytestmark = pytest.mark.integration + + +def _key(sha: str = "c" * 64) -> ParseKey: + return ParseKey.for_document( + sha, etl_service="LLAMACLOUD", mode="basic", version=1 + ) + + +async def test_recall_is_a_miss_for_an_unknown_key(db_session, cache_local_storage): + service = EtlCacheService(db_session) + assert await service.recall(_key()) is None + + +async def test_remembered_parse_recalls_as_equivalent_result( + db_session, cache_local_storage +): + service = EtlCacheService(db_session) + stored = EtlResult( + markdown_content="# Cached doc\n\nBody paragraph.\n", + etl_service="LLAMACLOUD", + actual_pages=7, + content_type="application/pdf", + ) + + await service.remember(_key(), stored) + recalled = await service.recall(_key()) + + assert recalled is not None + assert recalled.markdown_content == stored.markdown_content + assert recalled.etl_service == "LLAMACLOUD" + assert recalled.actual_pages == 7 + assert recalled.content_type == "application/pdf" + + +async def test_repeated_recall_keeps_serving_the_same_content( + db_session, cache_local_storage +): + service = EtlCacheService(db_session) + stored = EtlResult( + markdown_content="# Stable\n", + etl_service="LLAMACLOUD", + actual_pages=1, + content_type="application/pdf", + ) + await service.remember(_key(), stored) + + first = await service.recall(_key()) + second = await service.recall(_key()) + + assert first is not None and second is not None + assert first.markdown_content == second.markdown_content == "# Stable\n" diff --git a/surfsense_backend/tests/integration/etl_pipeline/cache/test_markdown_store.py b/surfsense_backend/tests/integration/etl_pipeline/cache/test_markdown_store.py new file mode 100644 index 000000000..a9d685017 --- /dev/null +++ b/surfsense_backend/tests/integration/etl_pipeline/cache/test_markdown_store.py @@ -0,0 +1,42 @@ +"""MarkdownCacheStore against a real local filesystem backend (no mocks). + +Proves the blob side of the cache: markdown written under a content-addressed key +comes back byte-for-byte, and a delete actually removes it. +""" + +from __future__ import annotations + +import pytest + +from app.etl_pipeline.cache.schemas import ParseKey +from app.etl_pipeline.cache.storage import MarkdownCacheStore +from app.etl_pipeline.cache.storage.object_keys import build_parse_object_key + +pytestmark = pytest.mark.integration + + +def _key() -> ParseKey: + return ParseKey.for_document( + "d" * 64, etl_service="LLAMACLOUD", mode="basic", version=1 + ) + + +async def test_save_then_load_round_trips_markdown(cache_local_storage): + store = MarkdownCacheStore() + markdown = "# Title\n\nBody with unicode: café, naïve, 漢字.\n" + + storage_key = await store.save(_key(), markdown) + + assert storage_key == build_parse_object_key(_key()) + assert await store.load(storage_key) == markdown + + +async def test_delete_removes_the_blob(cache_local_storage): + store = MarkdownCacheStore() + storage_key = await store.save(_key(), "to be deleted") + + await store.delete(storage_key) + + # Eviction deleted the blob; a later read must fail rather than serve stale. + with pytest.raises(FileNotFoundError): + await store.load(storage_key) From 1460173dadac5aca8faa6ee8d5e0d5fc96a64a6e Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:50:57 +0200 Subject: [PATCH 067/212] test(etl-cache): cover extract_with_cache end-to-end --- .../cache/test_cached_extraction.py | 84 +++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_extraction.py diff --git a/surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_extraction.py b/surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_extraction.py new file mode 100644 index 000000000..0b4a3dcf0 --- /dev/null +++ b/surfsense_backend/tests/integration/etl_pipeline/cache/test_cached_extraction.py @@ -0,0 +1,84 @@ +"""extract_with_cache end-to-end: real DB + real local storage. + +The only seam mocked is the parser itself (``EtlPipelineService.extract``) -- the +external boundary the facade wraps. Everything else (eligibility, hashing, recall, +remember, blob I/O) runs for real, so these tests prove the actual cost saving: +identical bytes are parsed once and reused. +""" + +from __future__ import annotations + +import pytest + +from app.config import config +from app.etl_pipeline.cache.cached_extraction import extract_with_cache +from app.etl_pipeline.etl_document import EtlRequest, EtlResult, ProcessingMode + +pytestmark = pytest.mark.integration + + +class _CountingParser: + """Stand-in for the external parser; records how often it actually ran.""" + + def __init__(self, **_kwargs) -> None: + pass + + calls = 0 + + async def extract(self, request: EtlRequest) -> EtlResult: + type(self).calls += 1 + return EtlResult( + markdown_content="# Parsed once\n", + etl_service="LLAMACLOUD", + actual_pages=3, + content_type="application/pdf", + ) + + +@pytest.fixture +def counting_parser(monkeypatch): + _CountingParser.calls = 0 + monkeypatch.setattr( + "app.etl_pipeline.cache.cached_extraction.EtlPipelineService", + _CountingParser, + ) + return _CountingParser + + +async def test_identical_uploads_are_parsed_once_then_served_from_cache( + tmp_path, monkeypatch, counting_parser, cache_local_storage, clean_cache_table +): + monkeypatch.setattr(config, "ETL_CACHE_ENABLED", True) + monkeypatch.setattr(config, "ETL_SERVICE", "LLAMACLOUD") + + pdf = tmp_path / "doc.pdf" + pdf.write_bytes(b"%PDF-1.4 unique-bytes-for-this-test") + request = EtlRequest( + file_path=str(pdf), filename="doc.pdf", processing_mode=ProcessingMode.BASIC + ) + + first = await extract_with_cache(request) + second = await extract_with_cache(request) + + assert counting_parser.calls == 1 # second upload reused the cache + assert first.markdown_content == second.markdown_content == "# Parsed once\n" + assert second.actual_pages == 3 + assert second.content_type == "application/pdf" + + +async def test_disabled_cache_parses_every_time( + tmp_path, monkeypatch, counting_parser +): + monkeypatch.setattr(config, "ETL_CACHE_ENABLED", False) + monkeypatch.setattr(config, "ETL_SERVICE", "LLAMACLOUD") + + pdf = tmp_path / "doc.pdf" + pdf.write_bytes(b"%PDF-1.4 another-unique-payload") + request = EtlRequest( + file_path=str(pdf), filename="doc.pdf", processing_mode=ProcessingMode.BASIC + ) + + await extract_with_cache(request) + await extract_with_cache(request) + + assert counting_parser.calls == 2 # bypassed: no reuse From d5e0280097a643a34b1e18189c872e4fb1bdc04b Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:54:36 +0200 Subject: [PATCH 068/212] test(etl-cache): cover two-phase eviction task on real infra --- .../etl_pipeline/cache/test_eviction_task.py | 96 +++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 surfsense_backend/tests/integration/etl_pipeline/cache/test_eviction_task.py diff --git a/surfsense_backend/tests/integration/etl_pipeline/cache/test_eviction_task.py b/surfsense_backend/tests/integration/etl_pipeline/cache/test_eviction_task.py new file mode 100644 index 000000000..e25cfaef0 --- /dev/null +++ b/surfsense_backend/tests/integration/etl_pipeline/cache/test_eviction_task.py @@ -0,0 +1,96 @@ +"""The eviction task on real infra: TTL expiry first, then coldest-over-budget. + +Seeds entries through the real cache (DB rows + local blobs), runs the actual +``_evict`` coroutine, and checks what survives via ``recall`` -- no mocks. TTL and +budget are driven through config so each phase can be exercised in isolation. +""" + +from __future__ import annotations + +import pytest + +from app.config import config +from app.etl_pipeline.cache.eviction.task import _evict +from app.etl_pipeline.cache.schemas import ParseKey +from app.etl_pipeline.cache.service import EtlCacheService +from app.etl_pipeline.etl_document import EtlResult +from app.tasks.celery_tasks import get_celery_session_maker + +pytestmark = pytest.mark.integration + + +def _key(sha: str) -> ParseKey: + return ParseKey.for_document( + sha, etl_service="LLAMACLOUD", mode="basic", version=1 + ) + + +def _result(markdown: str) -> EtlResult: + return EtlResult( + markdown_content=markdown, + etl_service="LLAMACLOUD", + actual_pages=1, + content_type="application/pdf", + ) + + +async def _remember(key: ParseKey, result: EtlResult) -> None: + async with get_celery_session_maker()() as session: + await EtlCacheService(session).remember(key, result) + + +async def _recall(key: ParseKey) -> EtlResult | None: + async with get_celery_session_maker()() as session: + return await EtlCacheService(session).recall(key) + + +async def test_expired_entries_are_pruned( + monkeypatch, cache_local_storage, clean_cache_table +): + monkeypatch.setattr(config, "ETL_CACHE_ENABLED", True) + monkeypatch.setattr(config, "ETL_CACHE_TTL_DAYS", -1) # cutoff in the future -> stale + monkeypatch.setattr(config, "ETL_CACHE_MAX_TOTAL_MB", 10_000) # size phase no-op + + key = _key("a" * 64) + await _remember(key, _result("# stale doc\n")) + + await _evict() + + assert await _recall(key) is None + + +async def test_coldest_entries_are_shed_when_over_budget( + monkeypatch, cache_local_storage, clean_cache_table +): + monkeypatch.setattr(config, "ETL_CACHE_ENABLED", True) + monkeypatch.setattr(config, "ETL_CACHE_TTL_DAYS", 3650) # nothing TTL-expired + monkeypatch.setattr(config, "ETL_CACHE_MAX_TOTAL_MB", 1) # ~1 MiB budget + + cold = _key("a" * 64) + warm = _key("b" * 64) + # Two ~0.6 MiB entries together exceed the 1 MiB budget; one must go. + await _remember(cold, _result("x" * 600_000)) + await _remember(warm, _result("y" * 600_000)) + + # A reuse makes `warm` warmer than `cold`, so `cold` is the eviction target. + assert await _recall(warm) is not None + + await _evict() + + assert await _recall(cold) is None + assert await _recall(warm) is not None + + +async def test_nothing_is_evicted_within_ttl_and_budget( + monkeypatch, cache_local_storage, clean_cache_table +): + monkeypatch.setattr(config, "ETL_CACHE_ENABLED", True) + monkeypatch.setattr(config, "ETL_CACHE_TTL_DAYS", 3650) + monkeypatch.setattr(config, "ETL_CACHE_MAX_TOTAL_MB", 10_000) + + key = _key("a" * 64) + await _remember(key, _result("# keep me\n")) + + await _evict() + + assert await _recall(key) is not None From 9efe24879dca2e19c799b102ea13524a3822a962 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:57:03 +0200 Subject: [PATCH 069/212] feat(observability): add etl cache lookup and eviction metrics --- .../app/observability/metrics.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/surfsense_backend/app/observability/metrics.py b/surfsense_backend/app/observability/metrics.py index 5ba3be059..4751278a4 100644 --- a/surfsense_backend/app/observability/metrics.py +++ b/surfsense_backend/app/observability/metrics.py @@ -289,6 +289,22 @@ def _etl_extract_outcome(): ) +@lru_cache(maxsize=1) +def _etl_cache_lookups(): + return _get_meter().create_counter( + "surfsense.etl.cache.lookups", + description="Count of ETL parse-cache lookups by outcome (hit/miss).", + ) + + +@lru_cache(maxsize=1) +def _etl_cache_evictions(): + return _get_meter().create_counter( + "surfsense.etl.cache.evictions", + description="Count of ETL parse-cache entries evicted, by phase.", + ) + + @lru_cache(maxsize=1) def _celery_heartbeat_refreshes(): return _get_meter().create_counter( @@ -670,6 +686,28 @@ def record_etl_extract_outcome( ) +def record_etl_cache_lookup( + *, etl_service: str | None, mode: str | None, outcome: str +) -> None: + """Record a parse-cache lookup. ``outcome`` is ``hit`` or ``miss``.""" + _add( + _etl_cache_lookups(), + 1, + { + "etl.service": etl_service or "unknown", + "mode": mode or "unknown", + "outcome": outcome, + }, + ) + + +def record_etl_cache_eviction(count: int, *, phase: str) -> None: + """Record evicted entries. ``phase`` is ``ttl`` or ``size``.""" + if count <= 0: + return + _add(_etl_cache_evictions(), count, {"phase": phase}) + + def record_celery_heartbeat_refresh(*, heartbeat_type: str) -> None: _add(_celery_heartbeat_refreshes(), 1, {"heartbeat.type": heartbeat_type}) @@ -866,6 +904,8 @@ __all__ = [ "record_compaction_run", "record_connector_sync_duration", "record_connector_sync_outcome", + "record_etl_cache_eviction", + "record_etl_cache_lookup", "record_etl_extract_duration", "record_etl_extract_outcome", "record_indexing_document_duration", From 0808fbcdee7d16755e3410604e6b61ebe54e2a67 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 11:57:03 +0200 Subject: [PATCH 070/212] feat(etl-cache): emit hit/miss and eviction metrics --- .../app/etl_pipeline/cache/cached_extraction.py | 8 ++++++++ .../app/etl_pipeline/cache/eviction/task.py | 10 +++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/surfsense_backend/app/etl_pipeline/cache/cached_extraction.py b/surfsense_backend/app/etl_pipeline/cache/cached_extraction.py index dba4b44da..b6a9e5531 100644 --- a/surfsense_backend/app/etl_pipeline/cache/cached_extraction.py +++ b/surfsense_backend/app/etl_pipeline/cache/cached_extraction.py @@ -13,6 +13,7 @@ from app.etl_pipeline.cache.service import EtlCacheService from app.etl_pipeline.cache.settings import load_etl_cache_settings from app.etl_pipeline.etl_document import EtlRequest, EtlResult from app.etl_pipeline.etl_pipeline_service import EtlPipelineService +from app.observability import metrics logger = logging.getLogger(__name__) @@ -43,8 +44,15 @@ async def extract_with_cache( cached_result = await _recall(key) if cached_result is not None: + metrics.record_etl_cache_lookup( + etl_service=key.etl_service, mode=key.mode, outcome="hit" + ) + logger.debug("ETL cache hit for %s", key.source_sha256) return cached_result + metrics.record_etl_cache_lookup( + etl_service=key.etl_service, mode=key.mode, outcome="miss" + ) result = await EtlPipelineService(vision_llm=vision_llm).extract(request) await _remember(key, result) return result diff --git a/surfsense_backend/app/etl_pipeline/cache/eviction/task.py b/surfsense_backend/app/etl_pipeline/cache/eviction/task.py index 98841b139..dcda10f61 100644 --- a/surfsense_backend/app/etl_pipeline/cache/eviction/task.py +++ b/surfsense_backend/app/etl_pipeline/cache/eviction/task.py @@ -12,6 +12,7 @@ from app.etl_pipeline.cache.persistence import CachedParseRepository from app.etl_pipeline.cache.schemas import EvictionCandidate from app.etl_pipeline.cache.settings import load_etl_cache_settings from app.etl_pipeline.cache.storage import MarkdownCacheStore +from app.observability import metrics from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -34,7 +35,7 @@ async def _evict() -> None: cutoff = datetime.now(UTC) - timedelta(days=settings.ttl_days) expired = await index.select_expired(cutoff=cutoff, limit=settings.eviction_batch) - await _drop(index, store, expired) + await _drop(index, store, expired, phase="ttl") total = await index.total_size_bytes() if total > settings.max_total_bytes: @@ -44,13 +45,15 @@ async def _evict() -> None: current_total_bytes=total, max_total_bytes=settings.max_total_bytes, ) - await _drop(index, store, over_budget) + await _drop(index, store, over_budget, phase="size") async def _drop( index: CachedParseRepository, store: MarkdownCacheStore, candidates: list[EvictionCandidate], + *, + phase: str, ) -> None: if not candidates: return @@ -59,4 +62,5 @@ async def _drop( with contextlib.suppress(Exception): await store.delete(candidate.storage_key) await index.delete_by_ids([candidate.id for candidate in candidates]) - logger.info("Evicted %d cached parses", len(candidates)) + metrics.record_etl_cache_eviction(len(candidates), phase=phase) + logger.info("Evicted %d cached parses (%s)", len(candidates), phase) From 99cf212c3160ab6e99af2e72718ed1d5e88e4c85 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 12:19:49 +0200 Subject: [PATCH 071/212] test: fix auth-mode mismatch and stale QuotaInsufficientError kwargs Pin AUTH_TYPE=LOCAL (and REGISTRATION_ENABLED=TRUE) in the test bootstrap so the email/password auth routers mount during integration tests regardless of a developer's .env=GOOGLE; without this the upload tests 404 on registration. Also update three tests to the current QuotaInsufficientError signature (balance_micros) after used_micros/limit_micros were removed. --- surfsense_backend/tests/conftest.py | 8 ++++++++ .../tests/integration/podcasts/test_draft_task.py | 3 +-- .../tests/unit/services/test_quota_checked_vision_llm.py | 3 +-- .../tests/unit/tasks/test_video_presentation_billing.py | 3 +-- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/surfsense_backend/tests/conftest.py b/surfsense_backend/tests/conftest.py index e2b586aa2..e227ed287 100644 --- a/surfsense_backend/tests/conftest.py +++ b/surfsense_backend/tests/conftest.py @@ -13,6 +13,14 @@ TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB) # DATABASE_URL in the environment (e.g. from .env or shell profile). os.environ["DATABASE_URL"] = TEST_DATABASE_URL +# Integration tests authenticate over HTTP via email/password, so the +# password-auth routers must be mounted (they are skipped under AUTH_TYPE=GOOGLE). +# setdefault (not load_dotenv, which runs later with override=False) lets a +# developer's .env=GOOGLE be overridden here while still honouring an explicitly +# exported shell AUTH_TYPE. +os.environ.setdefault("AUTH_TYPE", "LOCAL") +os.environ.setdefault("REGISTRATION_ENABLED", "TRUE") + import pytest # noqa: E402 from app.db import DocumentType # noqa: E402 diff --git a/surfsense_backend/tests/integration/podcasts/test_draft_task.py b/surfsense_backend/tests/integration/podcasts/test_draft_task.py index 7dadfc2f5..e9c9e4a9c 100644 --- a/surfsense_backend/tests/integration/podcasts/test_draft_task.py +++ b/surfsense_backend/tests/integration/podcasts/test_draft_task.py @@ -76,8 +76,7 @@ async def test_quota_denial_fails_the_podcast_without_a_transcript( async def _deny(**_kwargs): raise QuotaInsufficientError( usage_type="podcast_generation", - used_micros=5_000_000, - limit_micros=5_000_000, + balance_micros=0, remaining_micros=0, ) yield # pragma: no cover - unreachable, satisfies the CM protocol diff --git a/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py index 9e35b6f9c..17df89135 100644 --- a/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py +++ b/surfsense_backend/tests/unit/services/test_quota_checked_vision_llm.py @@ -105,8 +105,7 @@ async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch): async def _denying_billable_call(**_kwargs): raise QuotaInsufficientError( usage_type="vision_extraction", - used_micros=5_000_000, - limit_micros=5_000_000, + balance_micros=0, remaining_micros=0, ) yield # unreachable but required for asynccontextmanager type diff --git a/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py index 423b64ddb..97c1551a5 100644 --- a/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py +++ b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py @@ -98,8 +98,7 @@ async def _denying_billable_call(**kwargs): _CALL_LOG.append(kwargs) raise QuotaInsufficientError( usage_type=kwargs.get("usage_type", "?"), - used_micros=5_000_000, - limit_micros=5_000_000, + balance_micros=0, remaining_micros=0, ) yield SimpleNamespace() # pragma: no cover From 0fb1d3d37b5e697bf0cf3b3286f11a8f57dfebb0 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 14:47:25 +0200 Subject: [PATCH 072/212] feat(etl-cache): route all file-based sources through the parse cache Every file ingestion path (Dropbox, Google Drive / Composio Drive, OneDrive, local folder, Obsidian, and the legacy upload handlers) now parses via the extract_with_cache facade instead of calling EtlPipelineService.extract directly, so identical bytes are deduplicated globally regardless of source. vision_llm is passed through, keeping the existing cacheability gate intact. --- .../connectors/dropbox/content_extractor.py | 7 ++++--- .../google_drive/content_extractor.py | 9 +++++---- .../connectors/onedrive/content_extractor.py | 9 +++++---- .../app/services/obsidian_plugin_indexer.py | 7 ++++--- .../local_folder_indexer.py | 7 ++++--- .../document_processors/file_processors.py | 19 +++++++++++-------- 6 files changed, 33 insertions(+), 25 deletions(-) diff --git a/surfsense_backend/app/connectors/dropbox/content_extractor.py b/surfsense_backend/app/connectors/dropbox/content_extractor.py index 372d2fc82..300010c26 100644 --- a/surfsense_backend/app/connectors/dropbox/content_extractor.py +++ b/surfsense_backend/app/connectors/dropbox/content_extractor.py @@ -90,11 +90,12 @@ async def download_and_extract_content( if error: return None, metadata, error + from app.etl_pipeline.cache import extract_with_cache from app.etl_pipeline.etl_document import EtlRequest - from app.etl_pipeline.etl_pipeline_service import EtlPipelineService - result = await EtlPipelineService(vision_llm=vision_llm).extract( - EtlRequest(file_path=temp_file_path, filename=file_name) + result = await extract_with_cache( + EtlRequest(file_path=temp_file_path, filename=file_name), + vision_llm=vision_llm, ) markdown = result.markdown_content return markdown, metadata, None diff --git a/surfsense_backend/app/connectors/google_drive/content_extractor.py b/surfsense_backend/app/connectors/google_drive/content_extractor.py index 59392831d..1ea047978 100644 --- a/surfsense_backend/app/connectors/google_drive/content_extractor.py +++ b/surfsense_backend/app/connectors/google_drive/content_extractor.py @@ -122,12 +122,13 @@ async def download_and_extract_content( async def _parse_file_to_markdown( file_path: str, filename: str, *, vision_llm=None ) -> str: - """Parse a local file to markdown using the unified ETL pipeline.""" + """Parse a local file to markdown via the cache-aware ETL pipeline.""" + from app.etl_pipeline.cache import extract_with_cache from app.etl_pipeline.etl_document import EtlRequest - from app.etl_pipeline.etl_pipeline_service import EtlPipelineService - result = await EtlPipelineService(vision_llm=vision_llm).extract( - EtlRequest(file_path=file_path, filename=filename) + result = await extract_with_cache( + EtlRequest(file_path=file_path, filename=filename), + vision_llm=vision_llm, ) return result.markdown_content diff --git a/surfsense_backend/app/connectors/onedrive/content_extractor.py b/surfsense_backend/app/connectors/onedrive/content_extractor.py index 3154f2eca..fb1d31fbc 100644 --- a/surfsense_backend/app/connectors/onedrive/content_extractor.py +++ b/surfsense_backend/app/connectors/onedrive/content_extractor.py @@ -84,11 +84,12 @@ async def download_and_extract_content( async def _parse_file_to_markdown( file_path: str, filename: str, *, vision_llm=None ) -> str: - """Parse a local file to markdown using the unified ETL pipeline.""" + """Parse a local file to markdown via the cache-aware ETL pipeline.""" + from app.etl_pipeline.cache import extract_with_cache from app.etl_pipeline.etl_document import EtlRequest - from app.etl_pipeline.etl_pipeline_service import EtlPipelineService - result = await EtlPipelineService(vision_llm=vision_llm).extract( - EtlRequest(file_path=file_path, filename=filename) + result = await extract_with_cache( + EtlRequest(file_path=file_path, filename=filename), + vision_llm=vision_llm, ) return result.markdown_content diff --git a/surfsense_backend/app/services/obsidian_plugin_indexer.py b/surfsense_backend/app/services/obsidian_plugin_indexer.py index 13f43d1ee..cd05d7935 100644 --- a/surfsense_backend/app/services/obsidian_plugin_indexer.py +++ b/surfsense_backend/app/services/obsidian_plugin_indexer.py @@ -199,11 +199,12 @@ async def _extract_binary_attachment_markdown( async def _run_etl_extract(*, file_path: str, filename: str, vision_llm): """Lazy-load ETL dependencies to avoid module-import cycles.""" + from app.etl_pipeline.cache import extract_with_cache from app.etl_pipeline.etl_document import EtlRequest - from app.etl_pipeline.etl_pipeline_service import EtlPipelineService - return await EtlPipelineService(vision_llm=vision_llm).extract( - EtlRequest(file_path=file_path, filename=filename) + return await extract_with_cache( + EtlRequest(file_path=file_path, filename=filename), + vision_llm=vision_llm, ) diff --git a/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py b/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py index 1a2d4b967..2505fa7c4 100644 --- a/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/local_folder_indexer.py @@ -162,12 +162,13 @@ async def _read_file_content( All file types (plaintext, audio, direct-convert, document, image) are handled by ``EtlPipelineService``. """ + from app.etl_pipeline.cache import extract_with_cache from app.etl_pipeline.etl_document import EtlRequest, ProcessingMode - from app.etl_pipeline.etl_pipeline_service import EtlPipelineService mode = ProcessingMode.coerce(processing_mode) - result = await EtlPipelineService(vision_llm=vision_llm).extract( - EtlRequest(file_path=file_path, filename=filename, processing_mode=mode) + result = await extract_with_cache( + EtlRequest(file_path=file_path, filename=filename, processing_mode=mode), + vision_llm=vision_llm, ) return result.markdown_content diff --git a/surfsense_backend/app/tasks/document_processors/file_processors.py b/surfsense_backend/app/tasks/document_processors/file_processors.py index 0c3d30766..174ac966d 100644 --- a/surfsense_backend/app/tasks/document_processors/file_processors.py +++ b/surfsense_backend/app/tasks/document_processors/file_processors.py @@ -1,8 +1,9 @@ """ File document processors orchestrating content extraction and indexing. -Delegates content extraction to ``app.etl_pipeline.EtlPipelineService`` and -keeps only orchestration concerns (notifications, logging, page limits, saving). +Delegates content extraction to the cache-aware ``extract_with_cache`` facade +(over ``EtlPipelineService``) and keeps only orchestration concerns +(notifications, logging, page limits, saving). """ from __future__ import annotations @@ -116,8 +117,8 @@ async def _log_page_divergence( async def _process_non_document_upload(ctx: _ProcessingContext) -> Document | None: """Extract content from a non-document file (plaintext/direct_convert/audio/image) via the unified ETL pipeline.""" + from app.etl_pipeline.cache import extract_with_cache from app.etl_pipeline.etl_document import EtlRequest - from app.etl_pipeline.etl_pipeline_service import EtlPipelineService await _notify(ctx, "parsing", "Processing file") await ctx.task_logger.log_task_progress( @@ -136,8 +137,9 @@ async def _process_non_document_upload(ctx: _ProcessingContext) -> Document | No vision_llm = await get_vision_llm(ctx.session, ctx.search_space_id) - etl_result = await EtlPipelineService(vision_llm=vision_llm).extract( - EtlRequest(file_path=ctx.file_path, filename=ctx.filename) + etl_result = await extract_with_cache( + EtlRequest(file_path=ctx.file_path, filename=ctx.filename), + vision_llm=vision_llm, ) with contextlib.suppress(Exception): @@ -183,8 +185,8 @@ async def _process_non_document_upload(ctx: _ProcessingContext) -> Document | No async def _process_document_upload(ctx: _ProcessingContext) -> Document | None: """Route a document file to the configured ETL service via the unified pipeline.""" + from app.etl_pipeline.cache import extract_with_cache from app.etl_pipeline.etl_document import EtlRequest, ProcessingMode - from app.etl_pipeline.etl_pipeline_service import EtlPipelineService from app.services.etl_credit_service import ( EtlCreditService, InsufficientCreditsError, @@ -237,13 +239,14 @@ async def _process_document_upload(ctx: _ProcessingContext) -> Document | None: vision_llm = await get_vision_llm(ctx.session, ctx.search_space_id) - etl_result = await EtlPipelineService(vision_llm=vision_llm).extract( + etl_result = await extract_with_cache( EtlRequest( file_path=ctx.file_path, filename=ctx.filename, estimated_pages=estimated_pages, processing_mode=mode, - ) + ), + vision_llm=vision_llm, ) with contextlib.suppress(Exception): From cf208365b471941b967d4f2514edc54bad20213c Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 16:48:01 +0200 Subject: [PATCH 073/212] feat(index-cache): add embedding set value objects --- .../cache/schemas/__init__.py | 12 ++++++++ .../cache/schemas/embedding_key.py | 27 +++++++++++++++++ .../cache/schemas/embedding_set.py | 29 +++++++++++++++++++ 3 files changed, 68 insertions(+) create mode 100644 surfsense_backend/app/indexing_pipeline/cache/schemas/__init__.py create mode 100644 surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_key.py create mode 100644 surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_set.py diff --git a/surfsense_backend/app/indexing_pipeline/cache/schemas/__init__.py b/surfsense_backend/app/indexing_pipeline/cache/schemas/__init__.py new file mode 100644 index 000000000..8714e2d86 --- /dev/null +++ b/surfsense_backend/app/indexing_pipeline/cache/schemas/__init__.py @@ -0,0 +1,12 @@ +"""Pure value objects for the index cache.""" + +from __future__ import annotations + +from .embedding_key import EmbeddingKey +from .embedding_set import CachedChunk, EmbeddingSet + +__all__ = [ + "CachedChunk", + "EmbeddingKey", + "EmbeddingSet", +] diff --git a/surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_key.py b/surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_key.py new file mode 100644 index 000000000..55d891e73 --- /dev/null +++ b/surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_key.py @@ -0,0 +1,27 @@ +"""Identity of a cacheable embedding set: equal keys yield identical vectors. + +Embeddings depend on the markdown text, the embedding model, and the chunker -- +never on how the markdown was produced. So the key is the markdown's own hash +plus the model and chunker recipe, not the upstream parse identity. +""" + +from __future__ import annotations + +import hashlib +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class EmbeddingKey: + markdown_sha256: str + embedding_model: str + embedding_dim: int + chunker_kind: str + chunker_version: int + + @property + def object_suffix(self) -> str: + # Fingerprint the model so distinct models never share a blob, while the + # markdown hash (the object's folder) stays human-readable. + fingerprint = hashlib.sha256(self.embedding_model.encode("utf-8")).hexdigest() + return f"{fingerprint[:16]}.{self.chunker_kind}.v{self.chunker_version}.emb" diff --git a/surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_set.py b/surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_set.py new file mode 100644 index 000000000..68c3a5211 --- /dev/null +++ b/surfsense_backend/app/indexing_pipeline/cache/schemas/embedding_set.py @@ -0,0 +1,29 @@ +"""The cached payload: a document's chunk texts paired with their vectors.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np + + +@dataclass(frozen=True, slots=True) +class CachedChunk: + text: str + embedding: np.ndarray + + +@dataclass(frozen=True, slots=True) +class EmbeddingSet: + """Everything the indexer needs to rebuild a document's chunks without embedding. + + ``summary_embedding`` is the document-level vector; ``chunks`` are the ordered + chunk texts and their vectors. + """ + + summary_embedding: np.ndarray + chunks: list[CachedChunk] + + @property + def chunk_count(self) -> int: + return len(self.chunks) From 59fa4c38c3613f47ae4336ed6a9f1a0d698b749a Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 16:48:01 +0200 Subject: [PATCH 074/212] feat(index-cache): add pickle-free blob serialization --- .../indexing_pipeline/cache/serialization.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 surfsense_backend/app/indexing_pipeline/cache/serialization.py diff --git a/surfsense_backend/app/indexing_pipeline/cache/serialization.py b/surfsense_backend/app/indexing_pipeline/cache/serialization.py new file mode 100644 index 000000000..f9d53b471 --- /dev/null +++ b/surfsense_backend/app/indexing_pipeline/cache/serialization.py @@ -0,0 +1,71 @@ +"""Serialize an EmbeddingSet to a compact, self-describing blob (no pickle). + +Layout: ``MAGIC | uint32 header_len | json header | float32 matrix``. The header +carries the dim, chunk count, and ordered chunk texts; the matrix holds the +summary vector followed by one row per chunk, all float32 for compactness. +""" + +from __future__ import annotations + +import json +import struct + +import numpy as np + +from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingSet + +# Marker at the start of every blob: "SurfSense EMBeddings, version 1"-> SSEMB1. Lets us +# reject foreign blobs and bump the trailing digit if the layout ever changes. +_MAGIC = b"SSEMB1" +# 4-byte big-endian unsigned int written before the variable-length JSON header, +# so the reader knows where the header ends and the float matrix begins. +_HEADER_LEN = struct.Struct(">I") + + +def serialize(embedding_set: EmbeddingSet) -> bytes: + summary = np.asarray(embedding_set.summary_embedding, dtype=np.float32).reshape(-1) + dim = int(summary.shape[0]) + + rows = [summary] + texts: list[str] = [] + for chunk in embedding_set.chunks: + vector = np.asarray(chunk.embedding, dtype=np.float32).reshape(-1) + if vector.shape[0] != dim: + raise ValueError("All vectors in an embedding set must share one dimension.") + rows.append(vector) + texts.append(chunk.text) + + matrix = np.stack(rows, axis=0) + header = json.dumps( + {"dim": dim, "count": len(texts), "texts": texts}, ensure_ascii=False + ).encode("utf-8") + return b"".join( + [_MAGIC, _HEADER_LEN.pack(len(header)), header, matrix.tobytes(order="C")] + ) + + +def deserialize(blob: bytes) -> EmbeddingSet: + view = memoryview(blob) + if bytes(view[: len(_MAGIC)]) != _MAGIC: + raise ValueError("Unrecognized embedding cache blob.") + + offset = len(_MAGIC) + (header_len,) = _HEADER_LEN.unpack(view[offset : offset + _HEADER_LEN.size]) + offset += _HEADER_LEN.size + + header = json.loads(bytes(view[offset : offset + header_len]).decode("utf-8")) + offset += header_len + + dim = int(header["dim"]) + count = int(header["count"]) + texts: list[str] = header["texts"] + + matrix = np.frombuffer(view[offset:], dtype=np.float32) + if matrix.shape[0] != (count + 1) * dim: + raise ValueError("Embedding cache blob is truncated or corrupt.") + matrix = matrix.reshape(count + 1, dim) + + return EmbeddingSet( + summary_embedding=matrix[0], + chunks=[CachedChunk(text=texts[i], embedding=matrix[i + 1]) for i in range(count)], + ) From f5411145446aa253cfe21802ce1666f484a0a006 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 16:48:01 +0200 Subject: [PATCH 075/212] feat(index-cache): add cached embedding set table and repository --- .../161_add_index_cache_embedding_sets.py | 53 ++++++++ surfsense_backend/app/db.py | 3 + .../cache/persistence/__init__.py | 11 ++ .../cache/persistence/models.py | 47 +++++++ .../cache/persistence/repository.py | 126 ++++++++++++++++++ 5 files changed, 240 insertions(+) create mode 100644 surfsense_backend/alembic/versions/161_add_index_cache_embedding_sets.py create mode 100644 surfsense_backend/app/indexing_pipeline/cache/persistence/__init__.py create mode 100644 surfsense_backend/app/indexing_pipeline/cache/persistence/models.py create mode 100644 surfsense_backend/app/indexing_pipeline/cache/persistence/repository.py diff --git a/surfsense_backend/alembic/versions/161_add_index_cache_embedding_sets.py b/surfsense_backend/alembic/versions/161_add_index_cache_embedding_sets.py new file mode 100644 index 000000000..8441dcf6e --- /dev/null +++ b/surfsense_backend/alembic/versions/161_add_index_cache_embedding_sets.py @@ -0,0 +1,53 @@ +"""add index_cache_embedding_sets table for content-addressed embedding reuse + +Revision ID: 161 +Revises: 160 +""" + +from collections.abc import Sequence + +from alembic import op + +revision: str = "161" +down_revision: str | None = "160" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.execute( + """ + CREATE TABLE IF NOT EXISTS index_cache_embedding_sets ( + id SERIAL PRIMARY KEY, + markdown_sha256 VARCHAR(64) NOT NULL, + embedding_model VARCHAR(255) NOT NULL, + embedding_dim INTEGER NOT NULL, + chunker_kind VARCHAR(8) NOT NULL, + chunker_version INTEGER NOT NULL, + storage_backend VARCHAR(32) NOT NULL, + storage_key TEXT NOT NULL, + size_bytes BIGINT NOT NULL, + chunk_count INTEGER NOT NULL DEFAULT 0, + times_reused BIGINT NOT NULL DEFAULT 0, + last_used_at TIMESTAMP WITH TIME ZONE NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + CONSTRAINT uq_index_cache_embedding_sets_key + UNIQUE (markdown_sha256, embedding_model, chunker_kind, chunker_version) + ); + """ + ) + + op.execute( + "CREATE INDEX IF NOT EXISTS ix_index_cache_embedding_sets_last_used_at " + "ON index_cache_embedding_sets(last_used_at);" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_index_cache_embedding_sets_created_at " + "ON index_cache_embedding_sets(created_at);" + ) + + +def downgrade() -> None: + op.execute("DROP INDEX IF EXISTS ix_index_cache_embedding_sets_created_at;") + op.execute("DROP INDEX IF EXISTS ix_index_cache_embedding_sets_last_used_at;") + op.execute("DROP TABLE IF EXISTS index_cache_embedding_sets;") diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 97843d395..9ec13f4e2 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -2866,6 +2866,9 @@ from app.automations.persistence import ( # noqa: E402, F401 ) from app.etl_pipeline.cache.persistence.models import CachedParse # noqa: E402, F401 from app.file_storage.persistence import DocumentFile # noqa: E402, F401 +from app.indexing_pipeline.cache.persistence.models import ( # noqa: E402, F401 + CachedEmbeddingSet, +) from app.notifications.persistence import Notification # noqa: E402, F401 from app.podcasts.persistence import ( # noqa: E402, F401 Podcast, diff --git a/surfsense_backend/app/indexing_pipeline/cache/persistence/__init__.py b/surfsense_backend/app/indexing_pipeline/cache/persistence/__init__.py new file mode 100644 index 000000000..62cde0d05 --- /dev/null +++ b/surfsense_backend/app/indexing_pipeline/cache/persistence/__init__.py @@ -0,0 +1,11 @@ +"""Database access for cached embedding sets.""" + +from __future__ import annotations + +from .models import CachedEmbeddingSet +from .repository import CachedEmbeddingSetRepository + +__all__ = [ + "CachedEmbeddingSet", + "CachedEmbeddingSetRepository", +] diff --git a/surfsense_backend/app/indexing_pipeline/cache/persistence/models.py b/surfsense_backend/app/indexing_pipeline/cache/persistence/models.py new file mode 100644 index 000000000..e33e470f0 --- /dev/null +++ b/surfsense_backend/app/indexing_pipeline/cache/persistence/models.py @@ -0,0 +1,47 @@ +"""``index_cache_embedding_sets``: one reusable chunk+embedding set per markdown.""" + +from __future__ import annotations + +from sqlalchemy import ( + BigInteger, + Column, + DateTime, + Index, + Integer, + String, + UniqueConstraint, +) + +from app.db import BaseModel, TimestampMixin + + +class CachedEmbeddingSet(BaseModel, TimestampMixin): + __tablename__ = "index_cache_embedding_sets" + + # Key: markdown text + the recipe that turned it into vectors. + markdown_sha256 = Column(String(64), nullable=False) + embedding_model = Column(String(255), nullable=False) + embedding_dim = Column(Integer, nullable=False) + chunker_kind = Column(String(8), nullable=False) + chunker_version = Column(Integer, nullable=False) + + # Where the embedding blob lives (kept out of the row to stay small). + storage_backend = Column(String(32), nullable=False) + storage_key = Column(String, nullable=False) + size_bytes = Column(BigInteger, nullable=False) + chunk_count = Column(Integer, nullable=False, default=0, server_default="0") + + # Drives eviction (popularity + recency). + times_reused = Column(BigInteger, nullable=False, default=0, server_default="0") + last_used_at = Column(DateTime(timezone=True), nullable=False) + + __table_args__ = ( + UniqueConstraint( + "markdown_sha256", + "embedding_model", + "chunker_kind", + "chunker_version", + name="uq_index_cache_embedding_sets_key", + ), + Index("ix_index_cache_embedding_sets_last_used_at", "last_used_at"), + ) diff --git a/surfsense_backend/app/indexing_pipeline/cache/persistence/repository.py b/surfsense_backend/app/indexing_pipeline/cache/persistence/repository.py new file mode 100644 index 000000000..0bb0f8f23 --- /dev/null +++ b/surfsense_backend/app/indexing_pipeline/cache/persistence/repository.py @@ -0,0 +1,126 @@ +"""CRUD and eviction selectors for ``index_cache_embedding_sets`` (no business rules).""" + +from __future__ import annotations + +from datetime import UTC, datetime + +from sqlalchemy import delete, func, select, update +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.ext.asyncio import AsyncSession + +from app.etl_pipeline.cache.schemas import EvictionCandidate +from app.indexing_pipeline.cache.schemas import EmbeddingKey + +from .models import CachedEmbeddingSet + +_EVICTION_COLUMNS = ( + CachedEmbeddingSet.id, + CachedEmbeddingSet.storage_key, + CachedEmbeddingSet.size_bytes, + CachedEmbeddingSet.last_used_at, + CachedEmbeddingSet.times_reused, +) + + +def _as_eviction_candidate(row) -> EvictionCandidate: + return EvictionCandidate( + id=row.id, + storage_key=row.storage_key, + size_bytes=row.size_bytes, + last_used_at=row.last_used_at, + times_reused=row.times_reused, + ) + + +class CachedEmbeddingSetRepository: + def __init__(self, session: AsyncSession) -> None: + self._session = session + + async def get(self, key: EmbeddingKey) -> CachedEmbeddingSet | None: + result = await self._session.execute( + select(CachedEmbeddingSet).where( + CachedEmbeddingSet.markdown_sha256 == key.markdown_sha256, + CachedEmbeddingSet.embedding_model == key.embedding_model, + CachedEmbeddingSet.chunker_kind == key.chunker_kind, + CachedEmbeddingSet.chunker_version == key.chunker_version, + ) + ) + return result.scalars().first() + + async def insert( + self, + *, + key: EmbeddingKey, + storage_backend: str, + storage_key: str, + size_bytes: int, + chunk_count: int, + ) -> None: + # Concurrent writers embed identical markdown, so a lost race is harmless. + now = datetime.now(UTC) + await self._session.execute( + pg_insert(CachedEmbeddingSet) + .values( + markdown_sha256=key.markdown_sha256, + embedding_model=key.embedding_model, + embedding_dim=key.embedding_dim, + chunker_kind=key.chunker_kind, + chunker_version=key.chunker_version, + storage_backend=storage_backend, + storage_key=storage_key, + size_bytes=size_bytes, + chunk_count=chunk_count, + times_reused=0, + last_used_at=now, + created_at=now, + ) + .on_conflict_do_nothing(constraint="uq_index_cache_embedding_sets_key") + ) + await self._session.commit() + + async def mark_used(self, row_id: int) -> None: + await self._session.execute( + update(CachedEmbeddingSet) + .where(CachedEmbeddingSet.id == row_id) + .values( + times_reused=CachedEmbeddingSet.times_reused + 1, + last_used_at=datetime.now(UTC), + ) + ) + await self._session.commit() + + async def total_size_bytes(self) -> int: + result = await self._session.execute( + select(func.coalesce(func.sum(CachedEmbeddingSet.size_bytes), 0)) + ) + return int(result.scalar() or 0) + + async def select_expired( + self, *, cutoff: datetime, limit: int + ) -> list[EvictionCandidate]: + result = await self._session.execute( + select(*_EVICTION_COLUMNS) + .where(CachedEmbeddingSet.last_used_at < cutoff) + .order_by(CachedEmbeddingSet.last_used_at.asc()) + .limit(limit) + ) + return [_as_eviction_candidate(row) for row in result] + + async def select_coldest(self, *, limit: int) -> list[EvictionCandidate]: + result = await self._session.execute( + select(*_EVICTION_COLUMNS) + .order_by( + CachedEmbeddingSet.times_reused.asc(), + CachedEmbeddingSet.last_used_at.asc(), + ) + .limit(limit) + ) + return [_as_eviction_candidate(row) for row in result] + + async def delete_by_ids(self, ids: list[int]) -> None: + if not ids: + return + await self._session.execute( + delete(CachedEmbeddingSet).where(CachedEmbeddingSet.id.in_(ids)) + ) + await self._session.commit() From ad6da7c6afbbef3aca89100fca0caefde47e3c0d Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 16:48:01 +0200 Subject: [PATCH 076/212] feat(index-cache): add embedding blob store sharing the cache backend --- .../cache/storage/__init__.py | 9 +++++ .../cache/storage/embedding_store.py | 39 +++++++++++++++++++ .../cache/storage/object_keys.py | 12 ++++++ 3 files changed, 60 insertions(+) create mode 100644 surfsense_backend/app/indexing_pipeline/cache/storage/__init__.py create mode 100644 surfsense_backend/app/indexing_pipeline/cache/storage/embedding_store.py create mode 100644 surfsense_backend/app/indexing_pipeline/cache/storage/object_keys.py diff --git a/surfsense_backend/app/indexing_pipeline/cache/storage/__init__.py b/surfsense_backend/app/indexing_pipeline/cache/storage/__init__.py new file mode 100644 index 000000000..72b04c34d --- /dev/null +++ b/surfsense_backend/app/indexing_pipeline/cache/storage/__init__.py @@ -0,0 +1,9 @@ +"""Blob storage for cached embedding sets.""" + +from __future__ import annotations + +from .embedding_store import EmbeddingCacheStore + +__all__ = [ + "EmbeddingCacheStore", +] diff --git a/surfsense_backend/app/indexing_pipeline/cache/storage/embedding_store.py b/surfsense_backend/app/indexing_pipeline/cache/storage/embedding_store.py new file mode 100644 index 000000000..48835a12b --- /dev/null +++ b/surfsense_backend/app/indexing_pipeline/cache/storage/embedding_store.py @@ -0,0 +1,39 @@ +"""Read and write cached embedding blobs through the shared cache backend. + +The blob backend is shared with the ETL parse cache (same bucket / root), so +markdown and its embeddings live side by side; only the object prefix differs. +""" + +from __future__ import annotations + +from app.etl_pipeline.cache.storage.backend import resolve_cache_backend +from app.indexing_pipeline.cache.serialization import deserialize, serialize +from app.indexing_pipeline.cache.schemas import EmbeddingKey, EmbeddingSet +from app.indexing_pipeline.cache.storage.object_keys import build_embedding_object_key + +_EMBEDDING_CONTENT_TYPE = "application/octet-stream" + + +class EmbeddingCacheStore: + def __init__(self) -> None: + self._backend = resolve_cache_backend() + + @property + def backend_name(self) -> str: + return self._backend.backend_name + + async def save(self, key: EmbeddingKey, embedding_set: EmbeddingSet) -> tuple[str, int]: + """Persist the embedding set and return its storage key and byte size.""" + blob = serialize(embedding_set) + storage_key = build_embedding_object_key(key) + await self._backend.put( + storage_key, blob, content_type=_EMBEDDING_CONTENT_TYPE + ) + return storage_key, len(blob) + + async def load(self, storage_key: str) -> EmbeddingSet: + chunks = [chunk async for chunk in self._backend.open_stream(storage_key)] + return deserialize(b"".join(chunks)) + + async def delete(self, storage_key: str) -> None: + await self._backend.delete(storage_key) diff --git a/surfsense_backend/app/indexing_pipeline/cache/storage/object_keys.py b/surfsense_backend/app/indexing_pipeline/cache/storage/object_keys.py new file mode 100644 index 000000000..90e0b8957 --- /dev/null +++ b/surfsense_backend/app/indexing_pipeline/cache/storage/object_keys.py @@ -0,0 +1,12 @@ +"""Object keys for cached embedding sets, namespaced under a dedicated prefix.""" + +from __future__ import annotations + +from app.indexing_pipeline.cache.schemas import EmbeddingKey + +CACHE_PREFIX = "index_cache" + + +def build_embedding_object_key(key: EmbeddingKey) -> str: + # Content-addressed: identical markdown + recipe always map to the same key. + return f"{CACHE_PREFIX}/{key.markdown_sha256}/{key.object_suffix}" From daccd304ee8bce407495964ecb0ea070f1c22716 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 16:48:10 +0200 Subject: [PATCH 077/212] feat(index-cache): add settings, eligibility, and config flags --- surfsense_backend/.env.example | 14 +++++++++ surfsense_backend/app/config/__init__.py | 11 +++++++ .../indexing_pipeline/cache/eligibility.py | 21 +++++++++++++ .../app/indexing_pipeline/cache/settings.py | 30 +++++++++++++++++++ 4 files changed, 76 insertions(+) create mode 100644 surfsense_backend/app/indexing_pipeline/cache/eligibility.py create mode 100644 surfsense_backend/app/indexing_pipeline/cache/settings.py diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index 1924756ce..03b8d9255 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -328,6 +328,20 @@ ETL_CACHE_ENABLED=false # ETL_CACHE_STORAGE_CONTAINER=surfsense-etl-cache # ETL_CACHE_STORAGE_LOCAL_PATH=/var/lib/surfsense/etl-cache +# Index Cache +# Reuse chunk+embedding output for identical markdown across workspaces (skips +# re-chunking and re-embedding). Blobs share the ETL_CACHE_STORAGE_* backend. +# Off by default. +INDEX_CACHE_ENABLED=false +# Bump to invalidate all cached embedding sets after a chunker change. +# INDEX_CACHE_CHUNKER_VERSION=1 +# Prune entries unused for this many days. +# INDEX_CACHE_TTL_DAYS=90 +# Soft cap on total cached embeddings; coldest entries are evicted past it. +# INDEX_CACHE_MAX_TOTAL_MB=5120 +# Rows deleted per eviction pass. +# INDEX_CACHE_EVICTION_BATCH=500 + # Daytona Sandbox (isolated code execution) # DAYTONA_SANDBOX_ENABLED=FALSE # DAYTONA_API_KEY=your-daytona-api-key diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 525fe160d..0b6c05c39 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -964,6 +964,17 @@ class Config: ETL_CACHE_STORAGE_CONTAINER = os.getenv("ETL_CACHE_STORAGE_CONTAINER") ETL_CACHE_STORAGE_LOCAL_PATH = os.getenv("ETL_CACHE_STORAGE_LOCAL_PATH") + # Index cache: reuse chunk+embedding output for identical markdown across + # workspaces. Blobs share the ETL_CACHE_STORAGE_* backend. + INDEX_CACHE_ENABLED = ( + os.getenv("INDEX_CACHE_ENABLED", "false").strip().lower() == "true" + ) + # Bump to invalidate every cached embedding set after a chunker change. + INDEX_CACHE_CHUNKER_VERSION = int(os.getenv("INDEX_CACHE_CHUNKER_VERSION", "1")) + INDEX_CACHE_TTL_DAYS = int(os.getenv("INDEX_CACHE_TTL_DAYS", "90")) + INDEX_CACHE_MAX_TOTAL_MB = int(os.getenv("INDEX_CACHE_MAX_TOTAL_MB", "5120")) + INDEX_CACHE_EVICTION_BATCH = int(os.getenv("INDEX_CACHE_EVICTION_BATCH", "500")) + # Proxy provider selection. Maps to a ProxyProvider implementation registered # in app/utils/proxy/registry.py. Add new vendors there and switch via this var. PROXY_PROVIDER = os.getenv("PROXY_PROVIDER", "anonymous_proxies") diff --git a/surfsense_backend/app/indexing_pipeline/cache/eligibility.py b/surfsense_backend/app/indexing_pipeline/cache/eligibility.py new file mode 100644 index 000000000..7dbf79202 --- /dev/null +++ b/surfsense_backend/app/indexing_pipeline/cache/eligibility.py @@ -0,0 +1,21 @@ +"""Gating rule: may this document be served from / written to the index cache?""" + +from __future__ import annotations + + +def is_index_cacheable( + *, + cache_enabled: bool, + embedding_model: str | None, + embedding_dim: int | None, +) -> bool: + """Cache only when a concrete embedding model and dimension are configured. + + Without a model there is nothing to key against, and without a dimension the + blob's integrity guard cannot run -- both bypass the cache. + """ + if not cache_enabled: + return False + if not embedding_model: + return False + return bool(embedding_dim) diff --git a/surfsense_backend/app/indexing_pipeline/cache/settings.py b/surfsense_backend/app/indexing_pipeline/cache/settings.py new file mode 100644 index 000000000..2991c0980 --- /dev/null +++ b/surfsense_backend/app/indexing_pipeline/cache/settings.py @@ -0,0 +1,30 @@ +"""Index-cache configuration resolved from the central ``Config``. + +The blob backend is intentionally not configured here: it is shared with the ETL +parse cache (see ``ETL_CACHE_STORAGE_*``). +""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class IndexCacheSettings: + enabled: bool + chunker_version: int + ttl_days: int + max_total_bytes: int + eviction_batch: int + + +def load_index_cache_settings() -> IndexCacheSettings: + from app.config import config + + return IndexCacheSettings( + enabled=config.INDEX_CACHE_ENABLED, + chunker_version=config.INDEX_CACHE_CHUNKER_VERSION, + ttl_days=config.INDEX_CACHE_TTL_DAYS, + max_total_bytes=config.INDEX_CACHE_MAX_TOTAL_MB * 1024 * 1024, + eviction_batch=config.INDEX_CACHE_EVICTION_BATCH, + ) From 4d6378e031ff0de284867ebe3cbf36d8f6028099 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 16:48:10 +0200 Subject: [PATCH 078/212] feat(observability): add index cache hit/miss and eviction metrics --- .../app/observability/metrics.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/surfsense_backend/app/observability/metrics.py b/surfsense_backend/app/observability/metrics.py index 4751278a4..61b380722 100644 --- a/surfsense_backend/app/observability/metrics.py +++ b/surfsense_backend/app/observability/metrics.py @@ -305,6 +305,22 @@ def _etl_cache_evictions(): ) +@lru_cache(maxsize=1) +def _index_cache_lookups(): + return _get_meter().create_counter( + "surfsense.index.cache.lookups", + description="Count of index (chunk+embedding) cache lookups by outcome (hit/miss).", + ) + + +@lru_cache(maxsize=1) +def _index_cache_evictions(): + return _get_meter().create_counter( + "surfsense.index.cache.evictions", + description="Count of index cache entries evicted, by phase.", + ) + + @lru_cache(maxsize=1) def _celery_heartbeat_refreshes(): return _get_meter().create_counter( @@ -708,6 +724,28 @@ def record_etl_cache_eviction(count: int, *, phase: str) -> None: _add(_etl_cache_evictions(), count, {"phase": phase}) +def record_index_cache_lookup( + *, embedding_model: str | None, chunker_kind: str | None, outcome: str +) -> None: + """Record an index-cache lookup. ``outcome`` is ``hit`` or ``miss``.""" + _add( + _index_cache_lookups(), + 1, + { + "embedding.model": embedding_model or "unknown", + "chunker.kind": chunker_kind or "unknown", + "outcome": outcome, + }, + ) + + +def record_index_cache_eviction(count: int, *, phase: str) -> None: + """Record evicted entries. ``phase`` is ``ttl`` or ``size``.""" + if count <= 0: + return + _add(_index_cache_evictions(), count, {"phase": phase}) + + def record_celery_heartbeat_refresh(*, heartbeat_type: str) -> None: _add(_celery_heartbeat_refreshes(), 1, {"heartbeat.type": heartbeat_type}) @@ -908,6 +946,8 @@ __all__ = [ "record_etl_cache_lookup", "record_etl_extract_duration", "record_etl_extract_outcome", + "record_index_cache_eviction", + "record_index_cache_lookup", "record_indexing_document_duration", "record_indexing_document_outcome", "record_interrupt", From e8938c119bdda6f44ac78d48011e004a5ce02294 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 16:48:10 +0200 Subject: [PATCH 079/212] feat(index-cache): add recall/remember service --- .../app/indexing_pipeline/cache/service.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 surfsense_backend/app/indexing_pipeline/cache/service.py diff --git a/surfsense_backend/app/indexing_pipeline/cache/service.py b/surfsense_backend/app/indexing_pipeline/cache/service.py new file mode 100644 index 000000000..942ba7d51 --- /dev/null +++ b/surfsense_backend/app/indexing_pipeline/cache/service.py @@ -0,0 +1,51 @@ +"""Recall and remember embedding sets, coordinating the index and blob store.""" + +from __future__ import annotations + +import logging + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.indexing_pipeline.cache.persistence import CachedEmbeddingSetRepository +from app.indexing_pipeline.cache.schemas import EmbeddingKey, EmbeddingSet +from app.indexing_pipeline.cache.storage import EmbeddingCacheStore + +logger = logging.getLogger(__name__) + + +class IndexCacheService: + def __init__(self, session: AsyncSession) -> None: + self._index = CachedEmbeddingSetRepository(session) + self._store = EmbeddingCacheStore() + + async def recall(self, key: EmbeddingKey) -> EmbeddingSet | None: + """Return the cached embedding set, or None on a miss.""" + row = await self._index.get(key) + if row is None: + return None + + try: + embedding_set = await self._store.load(row.storage_key) + except Exception: + # Index points at a blob that is gone; treat as a miss and re-embed. + logger.warning("Cache blob missing: %s", row.storage_key, exc_info=True) + return None + + if int(embedding_set.summary_embedding.shape[0]) != key.embedding_dim: + # A model swapped its dimension under a reused name; never serve it. + logger.warning("Cached embedding dimension mismatch: %s", row.storage_key) + return None + + await self._index.mark_used(row.id) + return embedding_set + + async def remember(self, key: EmbeddingKey, embedding_set: EmbeddingSet) -> None: + """Store a freshly embedded set for future reuse.""" + storage_key, size_bytes = await self._store.save(key, embedding_set) + await self._index.insert( + key=key, + storage_backend=self._store.backend_name, + storage_key=storage_key, + size_bytes=size_bytes, + chunk_count=embedding_set.chunk_count, + ) From 019aa7bf76dd1dd68e3f54321dbf21fcd9afe43a Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 16:48:18 +0200 Subject: [PATCH 080/212] feat(index-cache): serve chunk embeddings from cache during indexing --- .../app/indexing_pipeline/cache/__init__.py | 11 ++ .../cache/cached_indexing.py | 121 ++++++++++++++++++ .../indexing_pipeline_service.py | 27 +--- 3 files changed, 138 insertions(+), 21 deletions(-) create mode 100644 surfsense_backend/app/indexing_pipeline/cache/__init__.py create mode 100644 surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py diff --git a/surfsense_backend/app/indexing_pipeline/cache/__init__.py b/surfsense_backend/app/indexing_pipeline/cache/__init__.py new file mode 100644 index 000000000..190b45f84 --- /dev/null +++ b/surfsense_backend/app/indexing_pipeline/cache/__init__.py @@ -0,0 +1,11 @@ +"""Content-addressed reuse of chunk+embedding output across workspaces.""" + +from __future__ import annotations + +from app.indexing_pipeline.cache.cached_indexing import build_chunk_embeddings +from app.indexing_pipeline.cache.service import IndexCacheService + +__all__ = [ + "IndexCacheService", + "build_chunk_embeddings", +] diff --git a/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py b/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py new file mode 100644 index 000000000..db07998a4 --- /dev/null +++ b/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py @@ -0,0 +1,121 @@ +"""Entry point: serve chunk embeddings from cache, embedding only on a miss. + +Embeddings are a pure function of the markdown, the embedding model, and the +chunker -- so identical markdown is chunked and embedded once and reused across +workspaces, even when it came from different sources. +""" + +from __future__ import annotations + +import asyncio +import hashlib +import logging + +import numpy as np + +from app.config import config +from app.indexing_pipeline.cache.eligibility import is_index_cacheable +from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingKey, EmbeddingSet +from app.indexing_pipeline.cache.service import IndexCacheService +from app.indexing_pipeline.cache.settings import load_index_cache_settings +from app.indexing_pipeline.document_chunker import chunk_text, chunk_text_hybrid +from app.indexing_pipeline.document_embedder import embed_texts +from app.observability import metrics + +logger = logging.getLogger(__name__) + +ChunkPair = tuple[str, np.ndarray] + + +async def build_chunk_embeddings( + markdown: str, *, use_code_chunker: bool +) -> tuple[np.ndarray, list[ChunkPair]]: + """Return the document-level vector and ordered ``(chunk_text, vector)`` pairs. + + Drop-in for the inline chunk+embed step; reuses prior output when the same + markdown has already been embedded with the current model and chunker. + """ + settings = load_index_cache_settings() + chunker_kind = "code" if use_code_chunker else "hybrid" + embedding_dim = getattr(config.embedding_model_instance, "dimension", None) + + cacheable = is_index_cacheable( + cache_enabled=settings.enabled, + embedding_model=config.EMBEDDING_MODEL, + embedding_dim=embedding_dim, + ) + if not cacheable: + return await _compute(markdown, use_code_chunker=use_code_chunker) + + key = EmbeddingKey( + markdown_sha256=_hash_text(markdown), + embedding_model=config.EMBEDDING_MODEL, + embedding_dim=int(embedding_dim), + chunker_kind=chunker_kind, + chunker_version=settings.chunker_version, + ) + + cached = await _recall(key) + if cached is not None: + metrics.record_index_cache_lookup( + embedding_model=key.embedding_model, chunker_kind=chunker_kind, outcome="hit" + ) + logger.debug("Index cache hit for %s", key.markdown_sha256) + return cached.summary_embedding, [(c.text, c.embedding) for c in cached.chunks] + + metrics.record_index_cache_lookup( + embedding_model=key.embedding_model, chunker_kind=chunker_kind, outcome="miss" + ) + summary_embedding, chunk_pairs = await _compute( + markdown, use_code_chunker=use_code_chunker + ) + await _remember(key, summary_embedding, chunk_pairs) + return summary_embedding, chunk_pairs + + +async def _compute( + markdown: str, *, use_code_chunker: bool +) -> tuple[np.ndarray, list[ChunkPair]]: + if use_code_chunker: + chunk_texts = await asyncio.to_thread( + chunk_text, markdown, use_code_chunker=True + ) + else: + # Table-aware hybrid chunker keeps Markdown tables intact (issue #1334). + chunk_texts = await asyncio.to_thread(chunk_text_hybrid, markdown) + + embeddings = await asyncio.to_thread(embed_texts, [markdown, *chunk_texts]) + summary_embedding, *chunk_embeddings = embeddings + return summary_embedding, list(zip(chunk_texts, chunk_embeddings, strict=False)) + + +async def _recall(key: EmbeddingKey) -> EmbeddingSet | None: + # Caching is best-effort: any failure falls through to a normal embed. + try: + from app.tasks.celery_tasks import get_celery_session_maker + + async with get_celery_session_maker()() as session: + return await IndexCacheService(session).recall(key) + except Exception: + logger.warning("Index cache recall failed; embedding fresh", exc_info=True) + return None + + +async def _remember( + key: EmbeddingKey, summary_embedding: np.ndarray, chunk_pairs: list[ChunkPair] +) -> None: + try: + from app.tasks.celery_tasks import get_celery_session_maker + + embedding_set = EmbeddingSet( + summary_embedding=summary_embedding, + chunks=[CachedChunk(text=text, embedding=vec) for text, vec in chunk_pairs], + ) + async with get_celery_session_maker()() as session: + await IndexCacheService(session).remember(key, embedding_set) + except Exception: + logger.warning("Index cache write failed; result not cached", exc_info=True) + + +def _hash_text(text: str) -> str: + return hashlib.sha256(text.encode("utf-8")).hexdigest() diff --git a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py index 67a6778e0..271b3ee03 100644 --- a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py +++ b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py @@ -19,9 +19,8 @@ from app.db import ( DocumentStatus, DocumentType, ) +from app.indexing_pipeline.cache import build_chunk_embeddings from app.indexing_pipeline.connector_document import ConnectorDocument -from app.indexing_pipeline.document_chunker import chunk_text, chunk_text_hybrid -from app.indexing_pipeline.document_embedder import embed_texts from app.indexing_pipeline.document_hashing import ( compute_content_hash, compute_identifier_hash, @@ -385,27 +384,13 @@ class IndexingPipelineService: ) t_step = time.perf_counter() - if connector_doc.should_use_code_chunker: - chunk_texts = await asyncio.to_thread( - chunk_text, - connector_doc.source_markdown, - use_code_chunker=True, - ) - else: - # Use the table-aware hybrid chunker so Markdown tables are not - # split mid-row (see issue #1334). - chunk_texts = await asyncio.to_thread( - chunk_text_hybrid, - connector_doc.source_markdown, - ) - - texts_to_embed = [content, *chunk_texts] - embeddings = await asyncio.to_thread(embed_texts, texts_to_embed) - summary_embedding, *chunk_embeddings = embeddings + summary_embedding, chunk_pairs = await build_chunk_embeddings( + content, + use_code_chunker=connector_doc.should_use_code_chunker, + ) chunks = [ - Chunk(content=text, embedding=emb) - for text, emb in zip(chunk_texts, chunk_embeddings, strict=False) + Chunk(content=text, embedding=emb) for text, emb in chunk_pairs ] perf.info( "[indexing] chunk+embed doc=%d chunks=%d in %.3fs", From 4e4f7f34faaa40102a8201a8e19c335c8b5dcd27 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 16:48:18 +0200 Subject: [PATCH 081/212] feat(index-cache): add TTL/size eviction task and daily schedule --- surfsense_backend/app/celery_app.py | 7 ++ .../cache/eviction/__init__.py | 9 +++ .../indexing_pipeline/cache/eviction/task.py | 68 +++++++++++++++++++ 3 files changed, 84 insertions(+) create mode 100644 surfsense_backend/app/indexing_pipeline/cache/eviction/__init__.py create mode 100644 surfsense_backend/app/indexing_pipeline/cache/eviction/task.py diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py index 413522189..38fb12a32 100644 --- a/surfsense_backend/app/celery_app.py +++ b/surfsense_backend/app/celery_app.py @@ -193,6 +193,7 @@ celery_app = Celery( "app.tasks.celery_tasks.auto_reload_task", "app.tasks.celery_tasks.gateway_tasks", "app.etl_pipeline.cache.eviction.task", + "app.indexing_pipeline.cache.eviction.task", "app.automations.tasks.execute_run", "app.automations.triggers.builtin.schedule.selector", "app.automations.triggers.builtin.event.selector", @@ -313,6 +314,12 @@ celery_app.conf.beat_schedule = { "schedule": crontab(hour="4", minute="0"), "options": {"expires": 600}, }, + # Prune the index cache (chunk+embedding sets) once daily, off-peak. + "evict-index-cache": { + "task": "evict_index_cache", + "schedule": crontab(hour="4", minute="30"), + "options": {"expires": 600}, + }, # Fire due automation schedule triggers (Beat entry owned by the schedule # trigger; see app.automations.triggers.builtin.schedule.source). **SCHEDULE_BEAT_SCHEDULE, diff --git a/surfsense_backend/app/indexing_pipeline/cache/eviction/__init__.py b/surfsense_backend/app/indexing_pipeline/cache/eviction/__init__.py new file mode 100644 index 000000000..de4df784e --- /dev/null +++ b/surfsense_backend/app/indexing_pipeline/cache/eviction/__init__.py @@ -0,0 +1,9 @@ +"""Background pruning of the index cache by age and size budget.""" + +from __future__ import annotations + +from .task import evict_index_cache_task + +__all__ = [ + "evict_index_cache_task", +] diff --git a/surfsense_backend/app/indexing_pipeline/cache/eviction/task.py b/surfsense_backend/app/indexing_pipeline/cache/eviction/task.py new file mode 100644 index 000000000..ab6885bca --- /dev/null +++ b/surfsense_backend/app/indexing_pipeline/cache/eviction/task.py @@ -0,0 +1,68 @@ +"""Celery task that prunes the index cache by TTL, then by size budget.""" + +from __future__ import annotations + +import contextlib +import logging +from datetime import UTC, datetime, timedelta + +from app.celery_app import celery_app +from app.etl_pipeline.cache.eviction.policy import select_over_budget +from app.etl_pipeline.cache.schemas import EvictionCandidate +from app.indexing_pipeline.cache.persistence import CachedEmbeddingSetRepository +from app.indexing_pipeline.cache.settings import load_index_cache_settings +from app.indexing_pipeline.cache.storage import EmbeddingCacheStore +from app.observability import metrics +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task + +logger = logging.getLogger(__name__) + + +@celery_app.task(name="evict_index_cache") +def evict_index_cache_task(): + return run_async_celery_task(_evict) + + +async def _evict() -> None: + """Expire stale entries, then shed the coldest overflow only if still over budget.""" + settings = load_index_cache_settings() + if not settings.enabled: + return + + store = EmbeddingCacheStore() + async with get_celery_session_maker()() as session: + index = CachedEmbeddingSetRepository(session) + + cutoff = datetime.now(UTC) - timedelta(days=settings.ttl_days) + expired = await index.select_expired( + cutoff=cutoff, limit=settings.eviction_batch + ) + await _drop(index, store, expired, phase="ttl") + + total = await index.total_size_bytes() + if total > settings.max_total_bytes: + coldest = await index.select_coldest(limit=settings.eviction_batch) + over_budget = select_over_budget( + coldest, + current_total_bytes=total, + max_total_bytes=settings.max_total_bytes, + ) + await _drop(index, store, over_budget, phase="size") + + +async def _drop( + index: CachedEmbeddingSetRepository, + store: EmbeddingCacheStore, + candidates: list[EvictionCandidate], + *, + phase: str, +) -> None: + if not candidates: + return + for candidate in candidates: + # Drop the index row even if the blob delete fails (orphan blob is harmless). + with contextlib.suppress(Exception): + await store.delete(candidate.storage_key) + await index.delete_by_ids([candidate.id for candidate in candidates]) + metrics.record_index_cache_eviction(len(candidates), phase=phase) + logger.info("Evicted %d cached embedding sets (%s)", len(candidates), phase) From 8cf578d965657bbf5e7f9e2f1513ed19d8b8622c Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 16:48:18 +0200 Subject: [PATCH 082/212] test(index-cache): add unit tests and repoint embed/chunk patch targets --- .../tests/e2e/fakes/embeddings.py | 4 +- .../tests/integration/conftest.py | 8 +-- .../integration/document_upload/conftest.py | 4 +- .../adapters/test_file_upload_adapter.py | 2 +- .../unit/indexing_pipeline/cache/conftest.py | 28 ++++++++++ .../cache/test_eligibility.py | 28 ++++++++++ .../cache/test_embedding_key.py | 31 +++++++++++ .../cache/test_serialization.py | 52 +++++++++++++++++++ .../test_index_batch_parallel.py | 10 ++-- 9 files changed, 153 insertions(+), 14 deletions(-) create mode 100644 surfsense_backend/tests/unit/indexing_pipeline/cache/conftest.py create mode 100644 surfsense_backend/tests/unit/indexing_pipeline/cache/test_eligibility.py create mode 100644 surfsense_backend/tests/unit/indexing_pipeline/cache/test_embedding_key.py create mode 100644 surfsense_backend/tests/unit/indexing_pipeline/cache/test_serialization.py diff --git a/surfsense_backend/tests/e2e/fakes/embeddings.py b/surfsense_backend/tests/e2e/fakes/embeddings.py index ab9e24df9..9a01fb84b 100644 --- a/surfsense_backend/tests/e2e/fakes/embeddings.py +++ b/surfsense_backend/tests/e2e/fakes/embeddings.py @@ -57,9 +57,9 @@ def install(patches: list[Any]) -> None: # Consumers that did `from app.utils.document_converters import embed_text/texts` ("app.indexing_pipeline.document_embedder.embed_text", fake_embed_text), ("app.indexing_pipeline.document_embedder.embed_texts", fake_embed_texts), - # Pipeline service binding (the actual call site for indexing.index) + # Index-cache facade binding (the actual call site for indexing.index) ( - "app.indexing_pipeline.indexing_pipeline_service.embed_texts", + "app.indexing_pipeline.cache.cached_indexing.embed_texts", fake_embed_texts, ), ] diff --git a/surfsense_backend/tests/integration/conftest.py b/surfsense_backend/tests/integration/conftest.py index 19f8e3d0a..8457047ec 100644 --- a/surfsense_backend/tests/integration/conftest.py +++ b/surfsense_backend/tests/integration/conftest.py @@ -127,7 +127,7 @@ async def db_search_space(db_session: AsyncSession, db_user: User) -> SearchSpac def patched_embed_texts(monkeypatch) -> MagicMock: mock = MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]) monkeypatch.setattr( - "app.indexing_pipeline.indexing_pipeline_service.embed_texts", + "app.indexing_pipeline.cache.cached_indexing.embed_texts", mock, ) return mock @@ -137,7 +137,7 @@ def patched_embed_texts(monkeypatch) -> MagicMock: def patched_embed_texts_raises(monkeypatch) -> MagicMock: mock = MagicMock(side_effect=RuntimeError("Embedding unavailable")) monkeypatch.setattr( - "app.indexing_pipeline.indexing_pipeline_service.embed_texts", + "app.indexing_pipeline.cache.cached_indexing.embed_texts", mock, ) return mock @@ -147,11 +147,11 @@ def patched_embed_texts_raises(monkeypatch) -> MagicMock: def patched_chunk_text(monkeypatch) -> MagicMock: mock = MagicMock(return_value=["Test chunk content."]) monkeypatch.setattr( - "app.indexing_pipeline.indexing_pipeline_service.chunk_text", + "app.indexing_pipeline.cache.cached_indexing.chunk_text", mock, ) monkeypatch.setattr( - "app.indexing_pipeline.indexing_pipeline_service.chunk_text_hybrid", + "app.indexing_pipeline.cache.cached_indexing.chunk_text_hybrid", mock, ) return mock diff --git a/surfsense_backend/tests/integration/document_upload/conftest.py b/surfsense_backend/tests/integration/document_upload/conftest.py index 812140be3..bd889360f 100644 --- a/surfsense_backend/tests/integration/document_upload/conftest.py +++ b/surfsense_backend/tests/integration/document_upload/conftest.py @@ -283,11 +283,11 @@ async def credits(): def _mock_external_apis(monkeypatch): """Mock LLM, embedding, and chunking — these are external API boundaries.""" monkeypatch.setattr( - "app.indexing_pipeline.indexing_pipeline_service.embed_texts", + "app.indexing_pipeline.cache.cached_indexing.embed_texts", MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]), ) monkeypatch.setattr( - "app.indexing_pipeline.indexing_pipeline_service.chunk_text", + "app.indexing_pipeline.cache.cached_indexing.chunk_text", MagicMock(return_value=["Test chunk content."]), ) diff --git a/surfsense_backend/tests/integration/indexing_pipeline/adapters/test_file_upload_adapter.py b/surfsense_backend/tests/integration/indexing_pipeline/adapters/test_file_upload_adapter.py index 311716052..814129c8d 100644 --- a/surfsense_backend/tests/integration/indexing_pipeline/adapters/test_file_upload_adapter.py +++ b/surfsense_backend/tests/integration/indexing_pipeline/adapters/test_file_upload_adapter.py @@ -177,7 +177,7 @@ async def test_reindex_sets_status_ready(db_session, db_search_space, db_user, m async def test_reindex_replaces_chunks(db_session, db_search_space, db_user, mocker): """Reindexing replaces old chunks with new content rather than appending.""" mocker.patch( - "app.indexing_pipeline.indexing_pipeline_service.chunk_text_hybrid", + "app.indexing_pipeline.cache.cached_indexing.chunk_text_hybrid", side_effect=[["Original chunk."], ["Updated chunk."]], ) diff --git a/surfsense_backend/tests/unit/indexing_pipeline/cache/conftest.py b/surfsense_backend/tests/unit/indexing_pipeline/cache/conftest.py new file mode 100644 index 000000000..081dddaa7 --- /dev/null +++ b/surfsense_backend/tests/unit/indexing_pipeline/cache/conftest.py @@ -0,0 +1,28 @@ +"""Stub the cache package __init__s so unit tests import only pure leaf modules. + +The real ``cache``/``storage``/``eviction``/``persistence`` __init__s eagerly +import the facade, file storage, Celery, and ``app.db`` -- none of which a pure +unit test should need. Turning those packages into bare namespace packages lets +``from app.indexing_pipeline.cache. import ...`` resolve the leaf module +without running the heavy __init__. ``schemas`` is left real (it is pure). +""" + +import sys +import types +from pathlib import Path + +_CACHE_DIR = Path(__file__).resolve().parents[4] / "app" / "indexing_pipeline" / "cache" + + +def _stub_namespace_package(dotted: str, fs_dir: Path) -> None: + if dotted in sys.modules: + return + module = types.ModuleType(dotted) + module.__path__ = [str(fs_dir)] + module.__package__ = dotted + sys.modules[dotted] = module + + +_stub_namespace_package("app.indexing_pipeline.cache", _CACHE_DIR) +_stub_namespace_package("app.indexing_pipeline.cache.storage", _CACHE_DIR / "storage") +_stub_namespace_package("app.indexing_pipeline.cache.eviction", _CACHE_DIR / "eviction") diff --git a/surfsense_backend/tests/unit/indexing_pipeline/cache/test_eligibility.py b/surfsense_backend/tests/unit/indexing_pipeline/cache/test_eligibility.py new file mode 100644 index 000000000..780a6c536 --- /dev/null +++ b/surfsense_backend/tests/unit/indexing_pipeline/cache/test_eligibility.py @@ -0,0 +1,28 @@ +from app.indexing_pipeline.cache.eligibility import is_index_cacheable + + +def test_disabled_cache_is_never_cacheable(): + assert not is_index_cacheable( + cache_enabled=False, embedding_model="m", embedding_dim=384 + ) + + +def test_missing_model_is_not_cacheable(): + assert not is_index_cacheable( + cache_enabled=True, embedding_model=None, embedding_dim=384 + ) + + +def test_missing_dimension_is_not_cacheable(): + assert not is_index_cacheable( + cache_enabled=True, embedding_model="m", embedding_dim=None + ) + assert not is_index_cacheable( + cache_enabled=True, embedding_model="m", embedding_dim=0 + ) + + +def test_enabled_with_model_and_dim_is_cacheable(): + assert is_index_cacheable( + cache_enabled=True, embedding_model="m", embedding_dim=384 + ) diff --git a/surfsense_backend/tests/unit/indexing_pipeline/cache/test_embedding_key.py b/surfsense_backend/tests/unit/indexing_pipeline/cache/test_embedding_key.py new file mode 100644 index 000000000..ce9c8672d --- /dev/null +++ b/surfsense_backend/tests/unit/indexing_pipeline/cache/test_embedding_key.py @@ -0,0 +1,31 @@ +from app.indexing_pipeline.cache.schemas import EmbeddingKey + + +def _key(**overrides) -> EmbeddingKey: + base = { + "markdown_sha256": "a" * 64, + "embedding_model": "openai://text-embedding-3-small", + "embedding_dim": 1536, + "chunker_kind": "hybrid", + "chunker_version": 1, + } + base.update(overrides) + return EmbeddingKey(**base) + + +def test_object_suffix_is_stable(): + assert _key().object_suffix == _key().object_suffix + + +def test_object_suffix_differs_by_model(): + assert _key().object_suffix != _key(embedding_model="local/minilm").object_suffix + + +def test_object_suffix_differs_by_chunker_kind_and_version(): + assert _key().object_suffix != _key(chunker_kind="code").object_suffix + assert _key().object_suffix != _key(chunker_version=2).object_suffix + + +def test_object_suffix_encodes_kind_and_version(): + suffix = _key(chunker_kind="code", chunker_version=3).object_suffix + assert suffix.endswith(".code.v3.emb") diff --git a/surfsense_backend/tests/unit/indexing_pipeline/cache/test_serialization.py b/surfsense_backend/tests/unit/indexing_pipeline/cache/test_serialization.py new file mode 100644 index 000000000..8db87bf1b --- /dev/null +++ b/surfsense_backend/tests/unit/indexing_pipeline/cache/test_serialization.py @@ -0,0 +1,52 @@ +import numpy as np +import pytest + +from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingSet +from app.indexing_pipeline.cache.serialization import deserialize, serialize + + +def _make_set(dim: int, n_chunks: int) -> EmbeddingSet: + rng = np.random.default_rng(0) + return EmbeddingSet( + summary_embedding=rng.random(dim, dtype=np.float64), + chunks=[ + CachedChunk(text=f"chunk {i}\nwith newline", embedding=rng.random(dim)) + for i in range(n_chunks) + ], + ) + + +def test_round_trip_preserves_texts_and_vectors(): + original = _make_set(dim=8, n_chunks=3) + + restored = deserialize(serialize(original)) + + assert [c.text for c in restored.chunks] == [c.text for c in original.chunks] + assert restored.chunk_count == 3 + assert np.allclose(restored.summary_embedding, original.summary_embedding, atol=1e-6) + for got, want in zip(restored.chunks, original.chunks, strict=True): + assert np.allclose(got.embedding, want.embedding, atol=1e-6) + + +def test_round_trip_with_no_chunks(): + original = _make_set(dim=4, n_chunks=0) + + restored = deserialize(serialize(original)) + + assert restored.chunk_count == 0 + assert restored.summary_embedding.shape[0] == 4 + + +def test_serialize_rejects_mismatched_dimensions(): + bad = EmbeddingSet( + summary_embedding=np.zeros(4, dtype=np.float32), + chunks=[CachedChunk(text="x", embedding=np.zeros(8, dtype=np.float32))], + ) + + with pytest.raises(ValueError): + serialize(bad) + + +def test_deserialize_rejects_foreign_blob(): + with pytest.raises(ValueError): + deserialize(b"not-a-surfsense-blob") diff --git a/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py b/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py index 3a1b77d90..252310061 100644 --- a/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py +++ b/surfsense_backend/tests/unit/indexing_pipeline/test_index_batch_parallel.py @@ -54,7 +54,7 @@ async def test_index_calls_embed_and_chunk_via_to_thread( mock_chunk_hybrid = MagicMock(return_value=["chunk1"]) mock_chunk_hybrid.__name__ = "chunk_text_hybrid" monkeypatch.setattr( - "app.indexing_pipeline.indexing_pipeline_service.chunk_text_hybrid", + "app.indexing_pipeline.cache.cached_indexing.chunk_text_hybrid", mock_chunk_hybrid, ) mock_embed = MagicMock( @@ -62,7 +62,7 @@ async def test_index_calls_embed_and_chunk_via_to_thread( ) mock_embed.__name__ = "embed_texts" monkeypatch.setattr( - "app.indexing_pipeline.indexing_pipeline_service.embed_texts", + "app.indexing_pipeline.cache.cached_indexing.embed_texts", mock_embed, ) # Bypass set_committed_value, which requires a real ORM instance (not MagicMock). @@ -102,17 +102,17 @@ async def test_non_code_documents_use_hybrid_chunker( mock_chunk_hybrid = MagicMock(return_value=["chunk1"]) mock_chunk_hybrid.__name__ = "chunk_text_hybrid" monkeypatch.setattr( - "app.indexing_pipeline.indexing_pipeline_service.chunk_text_hybrid", + "app.indexing_pipeline.cache.cached_indexing.chunk_text_hybrid", mock_chunk_hybrid, ) mock_chunk_code = MagicMock(return_value=["chunk1"]) mock_chunk_code.__name__ = "chunk_text" monkeypatch.setattr( - "app.indexing_pipeline.indexing_pipeline_service.chunk_text", + "app.indexing_pipeline.cache.cached_indexing.chunk_text", mock_chunk_code, ) monkeypatch.setattr( - "app.indexing_pipeline.indexing_pipeline_service.embed_texts", + "app.indexing_pipeline.cache.cached_indexing.embed_texts", MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]), ) monkeypatch.setattr( From 91d947ff79f6952d3d0c40217e118d10e09bfc8d Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 17:00:01 +0200 Subject: [PATCH 083/212] refactor(embedding-cache): rename index cache to embedding cache The cached payload is the indexing pipeline's embeddings (markdown is chunked then embedded), so "embedding cache" names the expensive output directly and removes the "index" ambiguity (DB index vs vector index vs indexing phase). Renames the service, settings, eligibility, eviction task, metrics, config flags (INDEX_CACHE_* -> EMBEDDING_CACHE_*), object prefix, and the table (index_cache_embedding_sets -> embedding_cache_sets) with its constraint and indexes. Migration 161 renamed accordingly. --- surfsense_backend/.env.example | 12 ++++----- ...ets.py => 161_add_embedding_cache_sets.py} | 20 +++++++------- surfsense_backend/app/celery_app.py | 6 ++--- surfsense_backend/app/config/__init__.py | 18 ++++++++----- .../app/indexing_pipeline/cache/__init__.py | 4 +-- .../cache/cached_indexing.py | 24 ++++++++--------- .../indexing_pipeline/cache/eligibility.py | 4 +-- .../cache/eviction/__init__.py | 6 ++--- .../indexing_pipeline/cache/eviction/task.py | 12 ++++----- .../cache/persistence/models.py | 8 +++--- .../cache/persistence/repository.py | 4 +-- .../cache/schemas/__init__.py | 2 +- .../app/indexing_pipeline/cache/service.py | 2 +- .../app/indexing_pipeline/cache/settings.py | 18 ++++++------- .../cache/storage/embedding_store.py | 2 +- .../cache/storage/object_keys.py | 2 +- .../app/observability/metrics.py | 26 +++++++++---------- .../cache/test_eligibility.py | 12 ++++----- 18 files changed, 93 insertions(+), 89 deletions(-) rename surfsense_backend/alembic/versions/{161_add_index_cache_embedding_sets.py => 161_add_embedding_cache_sets.py} (61%) diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index 03b8d9255..ac289c5a6 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -328,19 +328,19 @@ ETL_CACHE_ENABLED=false # ETL_CACHE_STORAGE_CONTAINER=surfsense-etl-cache # ETL_CACHE_STORAGE_LOCAL_PATH=/var/lib/surfsense/etl-cache -# Index Cache +# Embedding Cache # Reuse chunk+embedding output for identical markdown across workspaces (skips # re-chunking and re-embedding). Blobs share the ETL_CACHE_STORAGE_* backend. # Off by default. -INDEX_CACHE_ENABLED=false +EMBEDDING_CACHE_ENABLED=false # Bump to invalidate all cached embedding sets after a chunker change. -# INDEX_CACHE_CHUNKER_VERSION=1 +# EMBEDDING_CACHE_CHUNKER_VERSION=1 # Prune entries unused for this many days. -# INDEX_CACHE_TTL_DAYS=90 +# EMBEDDING_CACHE_TTL_DAYS=90 # Soft cap on total cached embeddings; coldest entries are evicted past it. -# INDEX_CACHE_MAX_TOTAL_MB=5120 +# EMBEDDING_CACHE_MAX_TOTAL_MB=5120 # Rows deleted per eviction pass. -# INDEX_CACHE_EVICTION_BATCH=500 +# EMBEDDING_CACHE_EVICTION_BATCH=500 # Daytona Sandbox (isolated code execution) # DAYTONA_SANDBOX_ENABLED=FALSE diff --git a/surfsense_backend/alembic/versions/161_add_index_cache_embedding_sets.py b/surfsense_backend/alembic/versions/161_add_embedding_cache_sets.py similarity index 61% rename from surfsense_backend/alembic/versions/161_add_index_cache_embedding_sets.py rename to surfsense_backend/alembic/versions/161_add_embedding_cache_sets.py index 8441dcf6e..70fb8b57a 100644 --- a/surfsense_backend/alembic/versions/161_add_index_cache_embedding_sets.py +++ b/surfsense_backend/alembic/versions/161_add_embedding_cache_sets.py @@ -1,4 +1,4 @@ -"""add index_cache_embedding_sets table for content-addressed embedding reuse +"""add embedding_cache_sets table for content-addressed embedding reuse Revision ID: 161 Revises: 160 @@ -17,7 +17,7 @@ depends_on: str | Sequence[str] | None = None def upgrade() -> None: op.execute( """ - CREATE TABLE IF NOT EXISTS index_cache_embedding_sets ( + CREATE TABLE IF NOT EXISTS embedding_cache_sets ( id SERIAL PRIMARY KEY, markdown_sha256 VARCHAR(64) NOT NULL, embedding_model VARCHAR(255) NOT NULL, @@ -31,23 +31,23 @@ def upgrade() -> None: times_reused BIGINT NOT NULL DEFAULT 0, last_used_at TIMESTAMP WITH TIME ZONE NOT NULL, created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), - CONSTRAINT uq_index_cache_embedding_sets_key + CONSTRAINT uq_embedding_cache_sets_key UNIQUE (markdown_sha256, embedding_model, chunker_kind, chunker_version) ); """ ) op.execute( - "CREATE INDEX IF NOT EXISTS ix_index_cache_embedding_sets_last_used_at " - "ON index_cache_embedding_sets(last_used_at);" + "CREATE INDEX IF NOT EXISTS ix_embedding_cache_sets_last_used_at " + "ON embedding_cache_sets(last_used_at);" ) op.execute( - "CREATE INDEX IF NOT EXISTS ix_index_cache_embedding_sets_created_at " - "ON index_cache_embedding_sets(created_at);" + "CREATE INDEX IF NOT EXISTS ix_embedding_cache_sets_created_at " + "ON embedding_cache_sets(created_at);" ) def downgrade() -> None: - op.execute("DROP INDEX IF EXISTS ix_index_cache_embedding_sets_created_at;") - op.execute("DROP INDEX IF EXISTS ix_index_cache_embedding_sets_last_used_at;") - op.execute("DROP TABLE IF EXISTS index_cache_embedding_sets;") + op.execute("DROP INDEX IF EXISTS ix_embedding_cache_sets_created_at;") + op.execute("DROP INDEX IF EXISTS ix_embedding_cache_sets_last_used_at;") + op.execute("DROP TABLE IF EXISTS embedding_cache_sets;") diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py index 38fb12a32..9704c0312 100644 --- a/surfsense_backend/app/celery_app.py +++ b/surfsense_backend/app/celery_app.py @@ -314,9 +314,9 @@ celery_app.conf.beat_schedule = { "schedule": crontab(hour="4", minute="0"), "options": {"expires": 600}, }, - # Prune the index cache (chunk+embedding sets) once daily, off-peak. - "evict-index-cache": { - "task": "evict_index_cache", + # Prune the embedding cache (chunk+embedding sets) once daily, off-peak. + "evict-embedding-cache": { + "task": "evict_embedding_cache", "schedule": crontab(hour="4", minute="30"), "options": {"expires": 600}, }, diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 0b6c05c39..549252cec 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -964,16 +964,20 @@ class Config: ETL_CACHE_STORAGE_CONTAINER = os.getenv("ETL_CACHE_STORAGE_CONTAINER") ETL_CACHE_STORAGE_LOCAL_PATH = os.getenv("ETL_CACHE_STORAGE_LOCAL_PATH") - # Index cache: reuse chunk+embedding output for identical markdown across + # Embedding cache: reuse chunk+embedding output for identical markdown across # workspaces. Blobs share the ETL_CACHE_STORAGE_* backend. - INDEX_CACHE_ENABLED = ( - os.getenv("INDEX_CACHE_ENABLED", "false").strip().lower() == "true" + EMBEDDING_CACHE_ENABLED = ( + os.getenv("EMBEDDING_CACHE_ENABLED", "false").strip().lower() == "true" ) # Bump to invalidate every cached embedding set after a chunker change. - INDEX_CACHE_CHUNKER_VERSION = int(os.getenv("INDEX_CACHE_CHUNKER_VERSION", "1")) - INDEX_CACHE_TTL_DAYS = int(os.getenv("INDEX_CACHE_TTL_DAYS", "90")) - INDEX_CACHE_MAX_TOTAL_MB = int(os.getenv("INDEX_CACHE_MAX_TOTAL_MB", "5120")) - INDEX_CACHE_EVICTION_BATCH = int(os.getenv("INDEX_CACHE_EVICTION_BATCH", "500")) + EMBEDDING_CACHE_CHUNKER_VERSION = int( + os.getenv("EMBEDDING_CACHE_CHUNKER_VERSION", "1") + ) + EMBEDDING_CACHE_TTL_DAYS = int(os.getenv("EMBEDDING_CACHE_TTL_DAYS", "90")) + EMBEDDING_CACHE_MAX_TOTAL_MB = int(os.getenv("EMBEDDING_CACHE_MAX_TOTAL_MB", "5120")) + EMBEDDING_CACHE_EVICTION_BATCH = int( + os.getenv("EMBEDDING_CACHE_EVICTION_BATCH", "500") + ) # Proxy provider selection. Maps to a ProxyProvider implementation registered # in app/utils/proxy/registry.py. Add new vendors there and switch via this var. diff --git a/surfsense_backend/app/indexing_pipeline/cache/__init__.py b/surfsense_backend/app/indexing_pipeline/cache/__init__.py index 190b45f84..d3b9e5f0d 100644 --- a/surfsense_backend/app/indexing_pipeline/cache/__init__.py +++ b/surfsense_backend/app/indexing_pipeline/cache/__init__.py @@ -3,9 +3,9 @@ from __future__ import annotations from app.indexing_pipeline.cache.cached_indexing import build_chunk_embeddings -from app.indexing_pipeline.cache.service import IndexCacheService +from app.indexing_pipeline.cache.service import EmbeddingCacheService __all__ = [ - "IndexCacheService", + "EmbeddingCacheService", "build_chunk_embeddings", ] diff --git a/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py b/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py index db07998a4..c93f2f133 100644 --- a/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py +++ b/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py @@ -14,10 +14,10 @@ import logging import numpy as np from app.config import config -from app.indexing_pipeline.cache.eligibility import is_index_cacheable +from app.indexing_pipeline.cache.eligibility import is_embedding_cacheable from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingKey, EmbeddingSet -from app.indexing_pipeline.cache.service import IndexCacheService -from app.indexing_pipeline.cache.settings import load_index_cache_settings +from app.indexing_pipeline.cache.service import EmbeddingCacheService +from app.indexing_pipeline.cache.settings import load_embedding_cache_settings from app.indexing_pipeline.document_chunker import chunk_text, chunk_text_hybrid from app.indexing_pipeline.document_embedder import embed_texts from app.observability import metrics @@ -35,11 +35,11 @@ async def build_chunk_embeddings( Drop-in for the inline chunk+embed step; reuses prior output when the same markdown has already been embedded with the current model and chunker. """ - settings = load_index_cache_settings() + settings = load_embedding_cache_settings() chunker_kind = "code" if use_code_chunker else "hybrid" embedding_dim = getattr(config.embedding_model_instance, "dimension", None) - cacheable = is_index_cacheable( + cacheable = is_embedding_cacheable( cache_enabled=settings.enabled, embedding_model=config.EMBEDDING_MODEL, embedding_dim=embedding_dim, @@ -57,13 +57,13 @@ async def build_chunk_embeddings( cached = await _recall(key) if cached is not None: - metrics.record_index_cache_lookup( + metrics.record_embedding_cache_lookup( embedding_model=key.embedding_model, chunker_kind=chunker_kind, outcome="hit" ) - logger.debug("Index cache hit for %s", key.markdown_sha256) + logger.debug("Embedding cache hit for %s", key.markdown_sha256) return cached.summary_embedding, [(c.text, c.embedding) for c in cached.chunks] - metrics.record_index_cache_lookup( + metrics.record_embedding_cache_lookup( embedding_model=key.embedding_model, chunker_kind=chunker_kind, outcome="miss" ) summary_embedding, chunk_pairs = await _compute( @@ -95,9 +95,9 @@ async def _recall(key: EmbeddingKey) -> EmbeddingSet | None: from app.tasks.celery_tasks import get_celery_session_maker async with get_celery_session_maker()() as session: - return await IndexCacheService(session).recall(key) + return await EmbeddingCacheService(session).recall(key) except Exception: - logger.warning("Index cache recall failed; embedding fresh", exc_info=True) + logger.warning("Embedding cache recall failed; embedding fresh", exc_info=True) return None @@ -112,9 +112,9 @@ async def _remember( chunks=[CachedChunk(text=text, embedding=vec) for text, vec in chunk_pairs], ) async with get_celery_session_maker()() as session: - await IndexCacheService(session).remember(key, embedding_set) + await EmbeddingCacheService(session).remember(key, embedding_set) except Exception: - logger.warning("Index cache write failed; result not cached", exc_info=True) + logger.warning("Embedding cache write failed; result not cached", exc_info=True) def _hash_text(text: str) -> str: diff --git a/surfsense_backend/app/indexing_pipeline/cache/eligibility.py b/surfsense_backend/app/indexing_pipeline/cache/eligibility.py index 7dbf79202..446bea2f8 100644 --- a/surfsense_backend/app/indexing_pipeline/cache/eligibility.py +++ b/surfsense_backend/app/indexing_pipeline/cache/eligibility.py @@ -1,9 +1,9 @@ -"""Gating rule: may this document be served from / written to the index cache?""" +"""Gating rule: may this document be served from / written to the embedding cache?""" from __future__ import annotations -def is_index_cacheable( +def is_embedding_cacheable( *, cache_enabled: bool, embedding_model: str | None, diff --git a/surfsense_backend/app/indexing_pipeline/cache/eviction/__init__.py b/surfsense_backend/app/indexing_pipeline/cache/eviction/__init__.py index de4df784e..a0f74b360 100644 --- a/surfsense_backend/app/indexing_pipeline/cache/eviction/__init__.py +++ b/surfsense_backend/app/indexing_pipeline/cache/eviction/__init__.py @@ -1,9 +1,9 @@ -"""Background pruning of the index cache by age and size budget.""" +"""Background pruning of the embedding cache by age and size budget.""" from __future__ import annotations -from .task import evict_index_cache_task +from .task import evict_embedding_cache_task __all__ = [ - "evict_index_cache_task", + "evict_embedding_cache_task", ] diff --git a/surfsense_backend/app/indexing_pipeline/cache/eviction/task.py b/surfsense_backend/app/indexing_pipeline/cache/eviction/task.py index ab6885bca..70eff6ea5 100644 --- a/surfsense_backend/app/indexing_pipeline/cache/eviction/task.py +++ b/surfsense_backend/app/indexing_pipeline/cache/eviction/task.py @@ -1,4 +1,4 @@ -"""Celery task that prunes the index cache by TTL, then by size budget.""" +"""Celery task that prunes the embedding cache by TTL, then by size budget.""" from __future__ import annotations @@ -10,7 +10,7 @@ from app.celery_app import celery_app from app.etl_pipeline.cache.eviction.policy import select_over_budget from app.etl_pipeline.cache.schemas import EvictionCandidate from app.indexing_pipeline.cache.persistence import CachedEmbeddingSetRepository -from app.indexing_pipeline.cache.settings import load_index_cache_settings +from app.indexing_pipeline.cache.settings import load_embedding_cache_settings from app.indexing_pipeline.cache.storage import EmbeddingCacheStore from app.observability import metrics from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task @@ -18,14 +18,14 @@ from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_ta logger = logging.getLogger(__name__) -@celery_app.task(name="evict_index_cache") -def evict_index_cache_task(): +@celery_app.task(name="evict_embedding_cache") +def evict_embedding_cache_task(): return run_async_celery_task(_evict) async def _evict() -> None: """Expire stale entries, then shed the coldest overflow only if still over budget.""" - settings = load_index_cache_settings() + settings = load_embedding_cache_settings() if not settings.enabled: return @@ -64,5 +64,5 @@ async def _drop( with contextlib.suppress(Exception): await store.delete(candidate.storage_key) await index.delete_by_ids([candidate.id for candidate in candidates]) - metrics.record_index_cache_eviction(len(candidates), phase=phase) + metrics.record_embedding_cache_eviction(len(candidates), phase=phase) logger.info("Evicted %d cached embedding sets (%s)", len(candidates), phase) diff --git a/surfsense_backend/app/indexing_pipeline/cache/persistence/models.py b/surfsense_backend/app/indexing_pipeline/cache/persistence/models.py index e33e470f0..af34d92d2 100644 --- a/surfsense_backend/app/indexing_pipeline/cache/persistence/models.py +++ b/surfsense_backend/app/indexing_pipeline/cache/persistence/models.py @@ -1,4 +1,4 @@ -"""``index_cache_embedding_sets``: one reusable chunk+embedding set per markdown.""" +"""``embedding_cache_sets``: one reusable chunk+embedding set per markdown.""" from __future__ import annotations @@ -16,7 +16,7 @@ from app.db import BaseModel, TimestampMixin class CachedEmbeddingSet(BaseModel, TimestampMixin): - __tablename__ = "index_cache_embedding_sets" + __tablename__ = "embedding_cache_sets" # Key: markdown text + the recipe that turned it into vectors. markdown_sha256 = Column(String(64), nullable=False) @@ -41,7 +41,7 @@ class CachedEmbeddingSet(BaseModel, TimestampMixin): "embedding_model", "chunker_kind", "chunker_version", - name="uq_index_cache_embedding_sets_key", + name="uq_embedding_cache_sets_key", ), - Index("ix_index_cache_embedding_sets_last_used_at", "last_used_at"), + Index("ix_embedding_cache_sets_last_used_at", "last_used_at"), ) diff --git a/surfsense_backend/app/indexing_pipeline/cache/persistence/repository.py b/surfsense_backend/app/indexing_pipeline/cache/persistence/repository.py index 0bb0f8f23..f7f1f4345 100644 --- a/surfsense_backend/app/indexing_pipeline/cache/persistence/repository.py +++ b/surfsense_backend/app/indexing_pipeline/cache/persistence/repository.py @@ -1,4 +1,4 @@ -"""CRUD and eviction selectors for ``index_cache_embedding_sets`` (no business rules).""" +"""CRUD and eviction selectors for ``embedding_cache_sets`` (no business rules).""" from __future__ import annotations @@ -74,7 +74,7 @@ class CachedEmbeddingSetRepository: last_used_at=now, created_at=now, ) - .on_conflict_do_nothing(constraint="uq_index_cache_embedding_sets_key") + .on_conflict_do_nothing(constraint="uq_embedding_cache_sets_key") ) await self._session.commit() diff --git a/surfsense_backend/app/indexing_pipeline/cache/schemas/__init__.py b/surfsense_backend/app/indexing_pipeline/cache/schemas/__init__.py index 8714e2d86..c200ca1a6 100644 --- a/surfsense_backend/app/indexing_pipeline/cache/schemas/__init__.py +++ b/surfsense_backend/app/indexing_pipeline/cache/schemas/__init__.py @@ -1,4 +1,4 @@ -"""Pure value objects for the index cache.""" +"""Pure value objects for the embedding cache.""" from __future__ import annotations diff --git a/surfsense_backend/app/indexing_pipeline/cache/service.py b/surfsense_backend/app/indexing_pipeline/cache/service.py index 942ba7d51..b1d634782 100644 --- a/surfsense_backend/app/indexing_pipeline/cache/service.py +++ b/surfsense_backend/app/indexing_pipeline/cache/service.py @@ -13,7 +13,7 @@ from app.indexing_pipeline.cache.storage import EmbeddingCacheStore logger = logging.getLogger(__name__) -class IndexCacheService: +class EmbeddingCacheService: def __init__(self, session: AsyncSession) -> None: self._index = CachedEmbeddingSetRepository(session) self._store = EmbeddingCacheStore() diff --git a/surfsense_backend/app/indexing_pipeline/cache/settings.py b/surfsense_backend/app/indexing_pipeline/cache/settings.py index 2991c0980..9c6737445 100644 --- a/surfsense_backend/app/indexing_pipeline/cache/settings.py +++ b/surfsense_backend/app/indexing_pipeline/cache/settings.py @@ -1,4 +1,4 @@ -"""Index-cache configuration resolved from the central ``Config``. +"""Embedding-cache configuration resolved from the central ``Config``. The blob backend is intentionally not configured here: it is shared with the ETL parse cache (see ``ETL_CACHE_STORAGE_*``). @@ -10,7 +10,7 @@ from dataclasses import dataclass @dataclass(frozen=True) -class IndexCacheSettings: +class EmbeddingCacheSettings: enabled: bool chunker_version: int ttl_days: int @@ -18,13 +18,13 @@ class IndexCacheSettings: eviction_batch: int -def load_index_cache_settings() -> IndexCacheSettings: +def load_embedding_cache_settings() -> EmbeddingCacheSettings: from app.config import config - return IndexCacheSettings( - enabled=config.INDEX_CACHE_ENABLED, - chunker_version=config.INDEX_CACHE_CHUNKER_VERSION, - ttl_days=config.INDEX_CACHE_TTL_DAYS, - max_total_bytes=config.INDEX_CACHE_MAX_TOTAL_MB * 1024 * 1024, - eviction_batch=config.INDEX_CACHE_EVICTION_BATCH, + return EmbeddingCacheSettings( + enabled=config.EMBEDDING_CACHE_ENABLED, + chunker_version=config.EMBEDDING_CACHE_CHUNKER_VERSION, + ttl_days=config.EMBEDDING_CACHE_TTL_DAYS, + max_total_bytes=config.EMBEDDING_CACHE_MAX_TOTAL_MB * 1024 * 1024, + eviction_batch=config.EMBEDDING_CACHE_EVICTION_BATCH, ) diff --git a/surfsense_backend/app/indexing_pipeline/cache/storage/embedding_store.py b/surfsense_backend/app/indexing_pipeline/cache/storage/embedding_store.py index 48835a12b..58c4a6cc1 100644 --- a/surfsense_backend/app/indexing_pipeline/cache/storage/embedding_store.py +++ b/surfsense_backend/app/indexing_pipeline/cache/storage/embedding_store.py @@ -7,8 +7,8 @@ markdown and its embeddings live side by side; only the object prefix differs. from __future__ import annotations from app.etl_pipeline.cache.storage.backend import resolve_cache_backend -from app.indexing_pipeline.cache.serialization import deserialize, serialize from app.indexing_pipeline.cache.schemas import EmbeddingKey, EmbeddingSet +from app.indexing_pipeline.cache.serialization import deserialize, serialize from app.indexing_pipeline.cache.storage.object_keys import build_embedding_object_key _EMBEDDING_CONTENT_TYPE = "application/octet-stream" diff --git a/surfsense_backend/app/indexing_pipeline/cache/storage/object_keys.py b/surfsense_backend/app/indexing_pipeline/cache/storage/object_keys.py index 90e0b8957..6286ccf90 100644 --- a/surfsense_backend/app/indexing_pipeline/cache/storage/object_keys.py +++ b/surfsense_backend/app/indexing_pipeline/cache/storage/object_keys.py @@ -4,7 +4,7 @@ from __future__ import annotations from app.indexing_pipeline.cache.schemas import EmbeddingKey -CACHE_PREFIX = "index_cache" +CACHE_PREFIX = "embedding_cache" def build_embedding_object_key(key: EmbeddingKey) -> str: diff --git a/surfsense_backend/app/observability/metrics.py b/surfsense_backend/app/observability/metrics.py index 61b380722..94bb55740 100644 --- a/surfsense_backend/app/observability/metrics.py +++ b/surfsense_backend/app/observability/metrics.py @@ -306,18 +306,18 @@ def _etl_cache_evictions(): @lru_cache(maxsize=1) -def _index_cache_lookups(): +def _embedding_cache_lookups(): return _get_meter().create_counter( - "surfsense.index.cache.lookups", - description="Count of index (chunk+embedding) cache lookups by outcome (hit/miss).", + "surfsense.embedding.cache.lookups", + description="Count of embedding (chunk+embedding) cache lookups by outcome (hit/miss).", ) @lru_cache(maxsize=1) -def _index_cache_evictions(): +def _embedding_cache_evictions(): return _get_meter().create_counter( - "surfsense.index.cache.evictions", - description="Count of index cache entries evicted, by phase.", + "surfsense.embedding.cache.evictions", + description="Count of embedding cache entries evicted, by phase.", ) @@ -724,12 +724,12 @@ def record_etl_cache_eviction(count: int, *, phase: str) -> None: _add(_etl_cache_evictions(), count, {"phase": phase}) -def record_index_cache_lookup( +def record_embedding_cache_lookup( *, embedding_model: str | None, chunker_kind: str | None, outcome: str ) -> None: - """Record an index-cache lookup. ``outcome`` is ``hit`` or ``miss``.""" + """Record an embedding-cache lookup. ``outcome`` is ``hit`` or ``miss``.""" _add( - _index_cache_lookups(), + _embedding_cache_lookups(), 1, { "embedding.model": embedding_model or "unknown", @@ -739,11 +739,11 @@ def record_index_cache_lookup( ) -def record_index_cache_eviction(count: int, *, phase: str) -> None: +def record_embedding_cache_eviction(count: int, *, phase: str) -> None: """Record evicted entries. ``phase`` is ``ttl`` or ``size``.""" if count <= 0: return - _add(_index_cache_evictions(), count, {"phase": phase}) + _add(_embedding_cache_evictions(), count, {"phase": phase}) def record_celery_heartbeat_refresh(*, heartbeat_type: str) -> None: @@ -942,12 +942,12 @@ __all__ = [ "record_compaction_run", "record_connector_sync_duration", "record_connector_sync_outcome", + "record_embedding_cache_eviction", + "record_embedding_cache_lookup", "record_etl_cache_eviction", "record_etl_cache_lookup", "record_etl_extract_duration", "record_etl_extract_outcome", - "record_index_cache_eviction", - "record_index_cache_lookup", "record_indexing_document_duration", "record_indexing_document_outcome", "record_interrupt", diff --git a/surfsense_backend/tests/unit/indexing_pipeline/cache/test_eligibility.py b/surfsense_backend/tests/unit/indexing_pipeline/cache/test_eligibility.py index 780a6c536..2e488231c 100644 --- a/surfsense_backend/tests/unit/indexing_pipeline/cache/test_eligibility.py +++ b/surfsense_backend/tests/unit/indexing_pipeline/cache/test_eligibility.py @@ -1,28 +1,28 @@ -from app.indexing_pipeline.cache.eligibility import is_index_cacheable +from app.indexing_pipeline.cache.eligibility import is_embedding_cacheable def test_disabled_cache_is_never_cacheable(): - assert not is_index_cacheable( + assert not is_embedding_cacheable( cache_enabled=False, embedding_model="m", embedding_dim=384 ) def test_missing_model_is_not_cacheable(): - assert not is_index_cacheable( + assert not is_embedding_cacheable( cache_enabled=True, embedding_model=None, embedding_dim=384 ) def test_missing_dimension_is_not_cacheable(): - assert not is_index_cacheable( + assert not is_embedding_cacheable( cache_enabled=True, embedding_model="m", embedding_dim=None ) - assert not is_index_cacheable( + assert not is_embedding_cacheable( cache_enabled=True, embedding_model="m", embedding_dim=0 ) def test_enabled_with_model_and_dim_is_cacheable(): - assert is_index_cacheable( + assert is_embedding_cacheable( cache_enabled=True, embedding_model="m", embedding_dim=384 ) From 412493ae08ab6d6333ece255334fec7c67c4d53d Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 17:33:21 +0200 Subject: [PATCH 084/212] test(embedding-cache): add integration tests for service, repository, and store Covers the public cache surface against real Postgres and a real local file backend (no mocks): recall miss, remember->recall vector/text/order round-trip, the dimension-mismatch refusal, the repository SQL behind eviction and dedup (size sum, coldest ordering, TTL cutoff, duplicate-key no-op, reuse counter), and the blob store save/load round-trip and delete. --- .../indexing_pipeline/cache/conftest.py | 33 ++++++ .../cache/test_cached_embedding_repository.py | 110 ++++++++++++++++++ .../cache/test_embedding_cache_service.py | 70 +++++++++++ .../cache/test_embedding_store.py | 63 ++++++++++ 4 files changed, 276 insertions(+) create mode 100644 surfsense_backend/tests/integration/indexing_pipeline/cache/conftest.py create mode 100644 surfsense_backend/tests/integration/indexing_pipeline/cache/test_cached_embedding_repository.py create mode 100644 surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_cache_service.py create mode 100644 surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_store.py diff --git a/surfsense_backend/tests/integration/indexing_pipeline/cache/conftest.py b/surfsense_backend/tests/integration/indexing_pipeline/cache/conftest.py new file mode 100644 index 000000000..6acb457ee --- /dev/null +++ b/surfsense_backend/tests/integration/indexing_pipeline/cache/conftest.py @@ -0,0 +1,33 @@ +"""Real-infra fixtures for the embedding-cache integration tests. + +``cache_local_storage`` points the shared cache backend at a throwaway directory +so tests exercise the real ``LocalFileBackend`` (no cloud, no mocks); the +embedding cache reuses the ETL cache backend, hence the ``ETL_CACHE_STORAGE_*`` +knobs. ``clean_embedding_cache_table`` removes rows written through the store's +own committing session, which the savepoint-rolled-back ``db_session`` cannot undo. +""" + +from __future__ import annotations + +import pytest +import pytest_asyncio +from sqlalchemy import text + + +@pytest.fixture +def cache_local_storage(tmp_path, monkeypatch): + from app.config import config + from app.etl_pipeline.cache.storage.backend import resolve_cache_backend + + monkeypatch.setattr(config, "ETL_CACHE_STORAGE_BACKEND", "local") + monkeypatch.setattr(config, "ETL_CACHE_STORAGE_LOCAL_PATH", str(tmp_path)) + resolve_cache_backend.cache_clear() + yield tmp_path + resolve_cache_backend.cache_clear() + + +@pytest_asyncio.fixture +async def clean_embedding_cache_table(async_engine): + yield + async with async_engine.begin() as conn: + await conn.execute(text("DELETE FROM embedding_cache_sets")) diff --git a/surfsense_backend/tests/integration/indexing_pipeline/cache/test_cached_embedding_repository.py b/surfsense_backend/tests/integration/indexing_pipeline/cache/test_cached_embedding_repository.py new file mode 100644 index 000000000..446932793 --- /dev/null +++ b/surfsense_backend/tests/integration/indexing_pipeline/cache/test_cached_embedding_repository.py @@ -0,0 +1,110 @@ +"""CachedEmbeddingSetRepository against real Postgres: the SQL behind eviction & dedup. + +These verify the parts only a real database can: the size accumulator, +coldest-first ordering by reuse then recency, TTL cutoff selection, the +insert-once guarantee under a duplicate key, and the reuse counter. +""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta + +import pytest + +from app.indexing_pipeline.cache.persistence import CachedEmbeddingSetRepository +from app.indexing_pipeline.cache.schemas import EmbeddingKey + +pytestmark = pytest.mark.integration + + +def _key(sha: str) -> EmbeddingKey: + return EmbeddingKey( + markdown_sha256=sha, + embedding_model="test-model", + embedding_dim=4, + chunker_kind="hybrid", + chunker_version=1, + ) + + +async def _insert(repo, *, sha, size=100, storage_key=None, chunk_count=1): + key = _key(sha) + await repo.insert( + key=key, + storage_backend="local", + storage_key=storage_key or f"embedding_cache/{sha}.emb", + size_bytes=size, + chunk_count=chunk_count, + ) + return key + + +async def test_total_size_bytes_sums_all_rows(db_session): + repo = CachedEmbeddingSetRepository(db_session) + await _insert(repo, sha="a" * 64, size=100) + await _insert(repo, sha="b" * 64, size=250) + + assert await repo.total_size_bytes() == 350 + + +async def test_select_coldest_orders_by_reuse_then_recency(db_session): + repo = CachedEmbeddingSetRepository(db_session) + ka = await _insert(repo, sha="a" * 64) + kb = await _insert(repo, sha="b" * 64) + kc = await _insert(repo, sha="c" * 64) + + # Warm B once and C twice; A stays untouched and should be coldest. + await repo.mark_used((await repo.get(kb)).id) + await repo.mark_used((await repo.get(kc)).id) + await repo.mark_used((await repo.get(kc)).id) + + coldest = await repo.select_coldest(limit=10) + + assert [c.id for c in coldest][:3] == [ + (await repo.get(ka)).id, + (await repo.get(kb)).id, + (await repo.get(kc)).id, + ] + + +async def test_select_expired_returns_only_rows_older_than_cutoff(db_session): + repo = CachedEmbeddingSetRepository(db_session) + await _insert(repo, sha="a" * 64) + + future = datetime.now(UTC) + timedelta(days=1) + past = datetime.now(UTC) - timedelta(days=1) + + # Row was just used, so it predates a future cutoff but not a past one. + assert len(await repo.select_expired(cutoff=future, limit=10)) == 1 + assert await repo.select_expired(cutoff=past, limit=10) == [] + + +async def test_duplicate_key_insert_keeps_the_first_row(db_session): + repo = CachedEmbeddingSetRepository(db_session) + key = await _insert( + repo, sha="a" * 64, size=100, storage_key="embedding_cache/first.emb" + ) + + # Same content-addressed key (a concurrent re-embed): must be a no-op. + await repo.insert( + key=key, + storage_backend="local", + storage_key="embedding_cache/second.emb", + size_bytes=999, + chunk_count=42, + ) + + row = await repo.get(key) + assert row.storage_key == "embedding_cache/first.emb" + assert await repo.total_size_bytes() == 100 + + +async def test_mark_used_increments_reuse_count(db_session): + repo = CachedEmbeddingSetRepository(db_session) + key = await _insert(repo, sha="a" * 64) + assert (await repo.get(key)).times_reused == 0 + + await repo.mark_used((await repo.get(key)).id) + await repo.mark_used((await repo.get(key)).id) + + assert (await repo.get(key)).times_reused == 2 diff --git a/surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_cache_service.py b/surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_cache_service.py new file mode 100644 index 000000000..2f4cd4a89 --- /dev/null +++ b/surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_cache_service.py @@ -0,0 +1,70 @@ +"""EmbeddingCacheService end-to-end against real Postgres + real local storage. + +Exercises the public cache surface -- ``recall`` / ``remember`` -- with no mocks: +a miss returns nothing, a remembered set comes back as equivalent vectors, and a +dimension mismatch is refused rather than served. +""" + +from __future__ import annotations + +import numpy as np +import pytest + +from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingKey, EmbeddingSet +from app.indexing_pipeline.cache.service import EmbeddingCacheService + +pytestmark = pytest.mark.integration + + +def _key(sha: str = "c" * 64, *, dim: int = 4) -> EmbeddingKey: + return EmbeddingKey( + markdown_sha256=sha, + embedding_model="test-model", + embedding_dim=dim, + chunker_kind="hybrid", + chunker_version=1, + ) + + +async def test_recall_is_a_miss_for_an_unknown_key(db_session, cache_local_storage): + service = EmbeddingCacheService(db_session) + assert await service.recall(_key()) is None + + +async def test_remembered_set_recalls_as_equivalent_vectors( + db_session, cache_local_storage, clean_embedding_cache_table +): + service = EmbeddingCacheService(db_session) + stored = EmbeddingSet( + summary_embedding=np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32), + chunks=[ + CachedChunk("first chunk", np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)), + CachedChunk("second chunk", np.array([0.0, 1.0, 0.0, 0.0], dtype=np.float32)), + ], + ) + + await service.remember(_key(), stored) + recalled = await service.recall(_key()) + + assert recalled is not None + assert np.array_equal(recalled.summary_embedding, stored.summary_embedding) + assert [c.text for c in recalled.chunks] == ["first chunk", "second chunk"] + assert np.array_equal(recalled.chunks[0].embedding, stored.chunks[0].embedding) + assert np.array_equal(recalled.chunks[1].embedding, stored.chunks[1].embedding) + + +async def test_recall_refuses_a_set_whose_dimension_changed( + db_session, cache_local_storage, clean_embedding_cache_table +): + # A model kept its name but changed its output width: never serve the stale blob. + service = EmbeddingCacheService(db_session) + stored = EmbeddingSet( + summary_embedding=np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32), + chunks=[CachedChunk("c", np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32))], + ) + await service.remember(_key(dim=4), stored) + + # Same identity (model + chunker + markdown), but the caller now expects dim 8. + recalled = await service.recall(_key(dim=8)) + + assert recalled is None diff --git a/surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_store.py b/surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_store.py new file mode 100644 index 000000000..83becd7b5 --- /dev/null +++ b/surfsense_backend/tests/integration/indexing_pipeline/cache/test_embedding_store.py @@ -0,0 +1,63 @@ +"""EmbeddingCacheStore against a real local filesystem backend (no mocks). + +Proves the blob side of the cache: an embedding set written under a +content-addressed key comes back with identical vectors, and a delete actually +removes it. +""" + +from __future__ import annotations + +import numpy as np +import pytest + +from app.indexing_pipeline.cache.schemas import CachedChunk, EmbeddingKey, EmbeddingSet +from app.indexing_pipeline.cache.storage import EmbeddingCacheStore +from app.indexing_pipeline.cache.storage.object_keys import build_embedding_object_key + +pytestmark = pytest.mark.integration + + +def _key() -> EmbeddingKey: + return EmbeddingKey( + markdown_sha256="d" * 64, + embedding_model="test-model", + embedding_dim=4, + chunker_kind="hybrid", + chunker_version=1, + ) + + +def _set() -> EmbeddingSet: + return EmbeddingSet( + summary_embedding=np.array([0.5, 0.25, 0.125, 0.0625], dtype=np.float32), + chunks=[ + CachedChunk("café, naïve, 漢字", np.array([1, 2, 3, 4], dtype=np.float32)), + CachedChunk("second", np.array([5, 6, 7, 8], dtype=np.float32)), + ], + ) + + +async def test_save_then_load_round_trips_the_embedding_set(cache_local_storage): + store = EmbeddingCacheStore() + embedding_set = _set() + + storage_key, size_bytes = await store.save(_key(), embedding_set) + loaded = await store.load(storage_key) + + assert storage_key == build_embedding_object_key(_key()) + assert size_bytes > 0 + assert np.array_equal(loaded.summary_embedding, embedding_set.summary_embedding) + assert [c.text for c in loaded.chunks] == ["café, naïve, 漢字", "second"] + assert np.array_equal(loaded.chunks[0].embedding, embedding_set.chunks[0].embedding) + assert np.array_equal(loaded.chunks[1].embedding, embedding_set.chunks[1].embedding) + + +async def test_delete_removes_the_blob(cache_local_storage): + store = EmbeddingCacheStore() + storage_key, _ = await store.save(_key(), _set()) + + await store.delete(storage_key) + + # Eviction deleted the blob; a later read must fail rather than serve stale. + with pytest.raises(FileNotFoundError): + await store.load(storage_key) From 356f0e56c5d5b94843436328b0ca5329299800a4 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 22:04:44 +0530 Subject: [PATCH 085/212] feat(model-connections): add provider-specific connection forms and shared components --- .../search-space-settings/layout-shell.tsx | 4 +- .../free-chat/free-model-selector.tsx | 4 +- .../components/new-chat/model-selector.tsx | 4 +- .../settings/model-connections-settings.tsx | 762 +++++------------- .../model-connections/azure-connect-form.tsx | 59 ++ .../bedrock-connect-form.tsx | 134 +++ .../model-connections/connect-fields.tsx | 83 ++ .../connection-settings-dialog.tsx | 249 ++++++ .../default-connect-form.tsx | 51 ++ .../settings/model-connections/model-utils.ts | 25 + .../models-selection-panel.tsx | 196 +++++ .../provider-connect-dialog.tsx | 154 ++++ .../model-connections/provider-metadata.tsx | 139 ++++ .../model-connections/vertex-connect-form.tsx | 127 +++ surfsense_web/lib/provider-icons.tsx | 6 +- 15 files changed, 1423 insertions(+), 574 deletions(-) create mode 100644 surfsense_web/components/settings/model-connections/azure-connect-form.tsx create mode 100644 surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx create mode 100644 surfsense_web/components/settings/model-connections/connect-fields.tsx create mode 100644 surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx create mode 100644 surfsense_web/components/settings/model-connections/default-connect-form.tsx create mode 100644 surfsense_web/components/settings/model-connections/model-utils.ts create mode 100644 surfsense_web/components/settings/model-connections/models-selection-panel.tsx create mode 100644 surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx create mode 100644 surfsense_web/components/settings/model-connections/provider-metadata.tsx create mode 100644 surfsense_web/components/settings/model-connections/vertex-connect-form.tsx diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx index 9d9045004..d30ea8a3a 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx @@ -2,8 +2,8 @@ import { BookText, - Bot, CircleUser, + Cpu, Earth, UserKey, } from "lucide-react"; @@ -54,7 +54,7 @@ export function SearchSpaceSettingsLayoutShell({ { value: "models" as const, label: t("nav_models"), - icon: , + icon: , }, { value: "team-roles" as const, diff --git a/surfsense_web/components/free-chat/free-model-selector.tsx b/surfsense_web/components/free-chat/free-model-selector.tsx index 9bf4ecee5..d04bca8a2 100644 --- a/surfsense_web/components/free-chat/free-model-selector.tsx +++ b/surfsense_web/components/free-chat/free-model-selector.tsx @@ -1,6 +1,6 @@ "use client"; -import { Bot, Check, ChevronDown } from "lucide-react"; +import { Check, ChevronDown, Cpu } from "lucide-react"; import { useRouter } from "next/navigation"; import { useCallback, useEffect, useMemo, useState } from "react"; import { Badge } from "@/components/ui/badge"; @@ -82,7 +82,7 @@ export function FreeModelSelector({ className }: { className?: string }) { ) : ( <> - + Select Model )} diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 6850096d6..dd1bcb431 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -1,7 +1,7 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { Bot, Check, ChevronDown, ImageOff, Search, Settings2, Zap } from "lucide-react"; +import { Check, ChevronDown, Cpu, ImageOff, Search, Settings2, Zap } from "lucide-react"; import { useMemo, useState } from "react"; import { updateModelRolesMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { @@ -222,7 +222,7 @@ export function ModelSelector({ {selected ? ( getProviderIcon(selected.provider, { className: "size-4" }) ) : ( - + )} {selected ? modelName(selected) : "Auto"} diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 9112cfe64..2e15ce2e9 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -1,17 +1,7 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { - Check, - CheckCircle2, - ChevronsUpDown, - Eye, - EyeOff, - RefreshCcw, - Settings, - Trash2, - XCircle, -} from "lucide-react"; +import { CheckCircle2, Trash2, XCircle } from "lucide-react"; import { useState } from "react"; import { addManualModelMutationAtom, @@ -19,10 +9,8 @@ import { createModelConnectionMutationAtom, deleteModelConnectionMutationAtom, discoverConnectionModelsMutationAtom, - updateModelConnectionMutationAtom, updateModelMutationAtom, updateModelRolesMutationAtom, - verifyModelConnectionMutationAtom, } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { globalModelConnectionsAtom, @@ -44,27 +32,7 @@ import { import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; -import { Checkbox } from "@/components/ui/checkbox"; -import { - Command, - CommandEmpty, - CommandGroup, - CommandInput, - CommandItem, - CommandList, -} from "@/components/ui/command"; -import { - Dialog, - DialogContent, - DialogDescription, - DialogFooter, - DialogHeader, - DialogTitle, - DialogTrigger, -} from "@/components/ui/dialog"; -import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; -import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Select, SelectContent, @@ -73,108 +41,16 @@ import { SelectValue, } from "@/components/ui/select"; import { Separator } from "@/components/ui/separator"; -import type { - ConnectionRead, - ConnectionUpdateRequest, - ModelRead, -} from "@/contracts/types/model-connections.types"; -import { getProviderIcon } from "@/lib/provider-icons"; -import { cn } from "@/lib/utils"; - -// Free-text URL hints (datalist), mirroring OpenWebUI. These never restrict -// what the user can type — any OpenAI-compatible endpoint works. -const URL_SUGGESTIONS = [ - "https://api.openai.com/v1", - "https://api.anthropic.com/v1", - "https://openrouter.ai/api/v1", - "https://generativelanguage.googleapis.com/v1beta/openai", - "https://api.groq.com/openai/v1", - "https://api.mistral.ai/v1", - "https://api.deepseek.com/v1", - "https://api.x.ai/v1", - "http://host.docker.internal:11434", - "http://host.docker.internal:1234/v1", - "http://host.docker.internal:8000/v1", -]; - -function modelLabel(model: ModelRead) { - return model.display_name || model.model_id; -} - -function capability(model: ModelRead, key: "chat" | "vision" | "image_gen") { - if (key === "chat") return Boolean(model.supports_chat); - if (key === "vision") return Boolean(model.supports_image_input); - return Boolean(model.supports_image_generation); -} - -type ModelCapabilityFilter = "chat" | "vision" | "image_gen"; - -const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: string }[] = [ - { key: "chat", label: "Chat" }, - { key: "vision", label: "Vision" }, - { key: "image_gen", label: "Image" }, -]; - -function UrlSuggestionCombobox({ - value, - onChange, - placeholder, -}: { - value: string; - onChange: (value: string) => void; - placeholder: string; -}) { - const [open, setOpen] = useState(false); - - return ( - - - - - - - - - - Use the custom URL you typed - - - {URL_SUGGESTIONS.map((url) => ( - { - onChange(url); - setOpen(false); - }} - > - - {url} - - ))} - - - - - - ); -} +import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; +import { ConnectionSettingsDialog } from "./model-connections/connection-settings-dialog"; +import { capability, modelLabel } from "./model-connections/model-utils"; +import { ProviderConnectDialog } from "./model-connections/provider-connect-dialog"; +import { + type ConnectionDraft, + PROVIDER_ORDER, + providerDisplay, + providerIcon, +} from "./model-connections/provider-metadata"; function StatusBadge({ connection }: { connection: ConnectionRead }) { if (connection.last_status === "OK") { @@ -198,7 +74,7 @@ function flattenModels(connections: ConnectionRead[]) { return connections.flatMap((connection) => connection.models.map((model) => ({ ...model, - connectionName: connection.provider, + connectionName: providerDisplay(connection.provider).name, connectionId: connection.id, provider: connection.provider, })) @@ -206,110 +82,21 @@ function flattenModels(connections: ConnectionRead[]) { } function ConnectionCard({ connection }: { connection: ConnectionRead }) { - const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom); - const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); - const updateConnection = useAtomValue(updateModelConnectionMutationAtom); const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom); - const addManualModel = useAtomValue(addManualModelMutationAtom); - const updateModel = useAtomValue(updateModelMutationAtom); - const bulkUpdateModels = useAtomValue(bulkUpdateModelsMutationAtom); - const allowlist = Array.isArray(connection.extra?.model_ids) - ? (connection.extra.model_ids as string[]) - : []; - const [isSettingsOpen, setIsSettingsOpen] = useState(false); - const [baseUrlDraft, setBaseUrlDraft] = useState(connection.base_url ?? ""); - const [apiKeyDraft, setApiKeyDraft] = useState(""); - const [showApiKey, setShowApiKey] = useState(false); - const [allowlistText, setAllowlistText] = useState(allowlist.join(", ")); - const [manualModelId, setManualModelId] = useState(""); - const [modelFilter, setModelFilter] = useState(null); - - const providerLabel = connection.provider; - const isLocal = - connection.provider === "ollama_chat" || - connection.provider === "lm_studio" || - !connection.base_url?.startsWith("https"); - const filteredModels = modelFilter - ? connection.models.filter((model) => capability(model, modelFilter)) - : connection.models; - const allFilteredModelsEnabled = - filteredModels.length > 0 && filteredModels.every((model) => model.enabled); - const hasConnectionChanges = - baseUrlDraft.trim() !== (connection.base_url ?? "") || - apiKeyDraft.trim() !== (connection.api_key ?? ""); - - function handleSettingsOpenChange(open: boolean) { - setIsSettingsOpen(open); - if (open) { - setBaseUrlDraft(connection.base_url ?? ""); - setApiKeyDraft(connection.api_key ?? ""); - setShowApiKey(false); - setAllowlistText(allowlist.join(", ")); - } - } - - function saveConnectionSettings() { - const data: ConnectionUpdateRequest = { - base_url: baseUrlDraft.trim() || null, - }; - - if (apiKeyDraft.trim() !== (connection.api_key ?? "")) { - data.api_key = apiKeyDraft.trim() || null; - } - - updateConnection.mutate( - { id: connection.id, data }, - { - onSuccess: () => setApiKeyDraft(""), - } - ); - } - - function saveAllowlist() { - const ids = allowlistText - .split(",") - .map((value) => value.trim()) - .filter(Boolean); - updateConnection.mutate({ - id: connection.id, - data: { extra: { ...(connection.extra ?? {}), model_ids: ids } }, - }); - } - - function addModel() { - const modelId = manualModelId.trim(); - if (!modelId) return; - addManualModel.mutate( - { connectionId: connection.id, data: { model_id: modelId } }, - { onSuccess: () => setManualModelId("") } - ); - } + const providerMeta = providerDisplay(connection.provider); + const providerLabel = providerMeta.name; function deleteCurrentConnection() { deleteConnection.mutate(connection.id); } - function toggleFilteredModels() { - const nextEnabled = !allFilteredModelsEnabled; - const modelIds = filteredModels - .filter((model) => model.enabled !== nextEnabled) - .map((model) => model.id); - - if (modelIds.length === 0) return; - - bulkUpdateModels.mutate({ - connectionId: connection.id, - data: { model_ids: modelIds, enabled: nextEnabled }, - }); - } - return ( -
-
+
+
- {getProviderIcon(providerLabel, { className: "size-4" })} + {providerIcon(connection.provider)} {providerLabel} {connection.scope === "GLOBAL" ? ( @@ -323,253 +110,7 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
- - - - - - -
- {getProviderIcon(providerLabel, { className: "size-5" })} -
- - Configure {providerLabel} - - - Manage credentials and choose which models are available from this provider. - -
-
-
- -
-
-
- - -

- Leave empty to use the provider default endpoint. -

-
- -
- -
- setApiKeyDraft(event.target.value)} - placeholder={connection.has_api_key ? "Saved API key" : "Paste an API key"} - type={showApiKey ? "text" : "password"} - className="pr-11" - /> - -
-
- - {!isLocal ? ( -
- -
- setAllowlistText(event.target.value)} - placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro" - /> - -
-

- Leave empty to discover all models. Recommended for providers with large - catalogs. -

-
- ) : null} - - - -
-
-
-
Models
-

- Select models to make available for this provider. -

-
-
- - -
-
- -
- setManualModelId(event.target.value)} - onKeyDown={(event) => { - if (event.key === "Enter") { - event.preventDefault(); - addModel(); - } - }} - placeholder="Add a model ID manually" - /> - -
- - {connection.models.length > 0 ? ( -
- - Filter models - - {MODEL_CAPABILITY_FILTERS.map((filter) => { - const count = connection.models.filter((model) => - capability(model, filter.key) - ).length; - const isActive = modelFilter === filter.key; - - return ( - - ); - })} -
- ) : null} - -
- {connection.models.length === 0 ? ( -
- No models yet. Use the refresh button to discover models or add one - manually. -
- ) : null} - {filteredModels.length === 0 && modelFilter ? ( -
- No{" "} - {MODEL_CAPABILITY_FILTERS.find( - (filter) => filter.key === modelFilter - )?.label.toLowerCase()}{" "} - models found on this connection. -
- ) : null} -
- {filteredModels.map((model) => ( -
- - updateModel.mutate({ - id: model.id, - data: { enabled: checked === true }, - }) - } - disabled={updateModel.isPending} - /> -
-
- {modelLabel(model)} - {model.source === "MANUAL" ? ( - - manual - - ) : null} -
-
- {["chat", "vision", "image_gen"] - .filter((key) => - capability(model, key as "chat" | "vision" | "image_gen") - ) - .join(", ") || "No discovered capabilities"} -
-
-
- ))} -
-
-
- - {connection.last_status && connection.last_status !== "OK" ? ( -

- {connection.last_error || "Could not list models."} Chat may still work; add - model IDs manually if discovery is unavailable. -

- ) : null} -
-
- - - - - -
-
+ -
-
-
+
+
+
+

Add Provider

- {selectedProvider - ? `${selectedProvider.transport} transport, ${selectedProvider.discovery} discovery.` - : "Choose a provider preset."}{" "} - Base URL is explicit and editable. Local URLs are tested from the backend container, - so use host.docker.internal instead of localhost. + SurfSense supports popular providers and self-hosted model endpoints.

+
+ {sortedProviders.map((item) => { + const meta = providerDisplay(item.provider); - {connections.length > 0 ? ( + return ( + + ); + })} +
+
+ + + + {connections.length > 0 ? ( +
+ +

Available Providers

- -

Available Providers

-
- {connections.map((connection) => ( - - ))} -
+ {connections.map((connection) => ( + + ))}
- ) : null} - - +
+ ) : null} +
diff --git a/surfsense_web/components/settings/model-connections/azure-connect-form.tsx b/surfsense_web/components/settings/model-connections/azure-connect-form.tsx new file mode 100644 index 000000000..11a2e25d3 --- /dev/null +++ b/surfsense_web/components/settings/model-connections/azure-connect-form.tsx @@ -0,0 +1,59 @@ +import { useState } from "react"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { ApiKeyField, ConnectFormFooter } from "./connect-fields"; +import { + isValidAzureTargetUri, + type ProviderConnectFormProps, + parseAzureTargetUri, +} from "./provider-metadata"; + +/** + * Azure OpenAI connect form. The user pastes a single Target URI, which we parse + * into api base, api version, and the deployment name (seeded as the model). + */ +export function AzureConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) { + const [targetUri, setTargetUri] = useState(""); + const [apiKey, setApiKey] = useState(""); + const canSubmit = isValidAzureTargetUri(targetUri) && Boolean(apiKey.trim()); + + function handleSubmit() { + const parsed = parseAzureTargetUri(targetUri); + onSubmit({ + base_url: parsed?.origin ?? null, + api_key: apiKey || null, + extra: parsed?.apiVersion ? { api_version: parsed.apiVersion } : {}, + seedModelId: parsed?.deploymentName || undefined, + }); + } + + return ( + <> +
+
+ + setTargetUri(event.target.value)} + placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview" + /> +

+ Paste your endpoint target URI from Azure OpenAI (including API base, deployment name, + and API version). +

+
+ +
+ + + ); +} diff --git a/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx b/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx new file mode 100644 index 000000000..9466c7cd1 --- /dev/null +++ b/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx @@ -0,0 +1,134 @@ +import { useState } from "react"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { ConnectFormFooter } from "./connect-fields"; +import { + AWS_REGION_OPTIONS, + BEDROCK_AUTH_ACCESS_KEY, + BEDROCK_AUTH_IAM, + BEDROCK_AUTH_LONG_TERM_API_KEY, + type ProviderConnectFormProps, +} from "./provider-metadata"; + +/** + * Amazon Bedrock connect form. Region + auth method drive which AWS credentials + * are collected; everything rides along in `extra.litellm_params`. + */ +export function BedrockConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) { + const [region, setRegion] = useState(""); + const [authMethod, setAuthMethod] = useState(BEDROCK_AUTH_ACCESS_KEY); + const [accessKeyId, setAccessKeyId] = useState(""); + const [secretAccessKey, setSecretAccessKey] = useState(""); + const [bearerToken, setBearerToken] = useState(""); + + const canSubmit = (() => { + if (!region) return false; + if (authMethod === BEDROCK_AUTH_ACCESS_KEY) { + return Boolean(accessKeyId && secretAccessKey); + } + if (authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY) { + return Boolean(bearerToken); + } + return true; + })(); + + function handleSubmit() { + const params: Record = { aws_region_name: region }; + if (authMethod === BEDROCK_AUTH_ACCESS_KEY) { + params.aws_access_key_id = accessKeyId; + params.aws_secret_access_key = secretAccessKey; + } else if (authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY) { + params.aws_bearer_token_bedrock = bearerToken; + } + onSubmit({ base_url: null, api_key: null, extra: { litellm_params: params } }); + } + + return ( + <> +
+
+ + +
+
+ + +
+ {authMethod === BEDROCK_AUTH_ACCESS_KEY ? ( + <> +
+ + setAccessKeyId(event.target.value)} + placeholder="AKIAIOSFODNN7EXAMPLE" + /> +
+
+ + setSecretAccessKey(event.target.value)} + placeholder="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + type="password" + /> +
+ + ) : null} + {authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY ? ( +
+ + setBearerToken(event.target.value)} + placeholder="Your long-term API key" + type="password" + /> +
+ ) : null} + {authMethod === BEDROCK_AUTH_IAM ? ( +

+ SurfSense will use the IAM role attached to the environment it's running in to + authenticate. +

+ ) : null} +

+ Add Bedrock model IDs from the provider's settings after connecting. +

+
+ + + ); +} diff --git a/surfsense_web/components/settings/model-connections/connect-fields.tsx b/surfsense_web/components/settings/model-connections/connect-fields.tsx new file mode 100644 index 000000000..af8db7f12 --- /dev/null +++ b/surfsense_web/components/settings/model-connections/connect-fields.tsx @@ -0,0 +1,83 @@ +import { Button } from "@/components/ui/button"; +import { DialogFooter } from "@/components/ui/dialog"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; + +interface ApiBaseUrlFieldProps { + value: string; + onChange: (value: string) => void; + optional?: boolean; + /** Placeholder, typically the provider's prefilled default base URL. */ + placeholder?: string; +} + +/** Shared API Base URL input. The prefilled default is passed in via `value`. */ +export function ApiBaseUrlField({ value, onChange, optional, placeholder }: ApiBaseUrlFieldProps) { + return ( +
+ + onChange(event.target.value)} + placeholder={placeholder || "https://api.example.com/v1"} + /> +

+ Local URLs are tested from the backend container, so use host.docker.internal instead of + localhost. +

+
+ ); +} + +interface ApiKeyFieldProps { + value: string; + onChange: (value: string) => void; + label?: string; + placeholder?: string; +} + +/** Shared masked API Key input. */ +export function ApiKeyField({ + value, + onChange, + label = "API Key", + placeholder = "API key", +}: ApiKeyFieldProps) { + return ( +
+ + onChange(event.target.value)} + placeholder={placeholder} + type="password" + /> +
+ ); +} + +interface ConnectFormFooterProps { + onCancel: () => void; + onSubmit: () => void; + canSubmit: boolean; + isPending: boolean; +} + +/** Shared Cancel / Connect footer for every provider connect form. */ +export function ConnectFormFooter({ + onCancel, + onSubmit, + canSubmit, + isPending, +}: ConnectFormFooterProps) { + return ( + + + + + ); +} diff --git a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx new file mode 100644 index 000000000..f3821af46 --- /dev/null +++ b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx @@ -0,0 +1,249 @@ +import { useAtomValue } from "jotai"; +import { Eye, EyeOff, Settings } from "lucide-react"; +import { useState } from "react"; +import { + addManualModelMutationAtom, + bulkUpdateModelsMutationAtom, + discoverConnectionModelsMutationAtom, + updateModelConnectionMutationAtom, + updateModelMutationAtom, + verifyModelConnectionMutationAtom, +} from "@/atoms/model-connections/model-connections-mutation.atoms"; +import { Button } from "@/components/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger, +} from "@/components/ui/dialog"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Separator } from "@/components/ui/separator"; +import type { + ConnectionRead, + ConnectionUpdateRequest, + ModelRead, +} from "@/contracts/types/model-connections.types"; +import { ModelsSelectionPanel } from "./models-selection-panel"; +import { providerIcon } from "./provider-metadata"; + +interface ConnectionSettingsDialogProps { + connection: ConnectionRead; + providerLabel: string; +} + +export function ConnectionSettingsDialog({ + connection, + providerLabel, +}: ConnectionSettingsDialogProps) { + const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom); + const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); + const updateConnection = useAtomValue(updateModelConnectionMutationAtom); + const addManualModel = useAtomValue(addManualModelMutationAtom); + const updateModel = useAtomValue(updateModelMutationAtom); + const bulkUpdateModels = useAtomValue(bulkUpdateModelsMutationAtom); + + const allowlist = Array.isArray(connection.extra?.model_ids) + ? (connection.extra.model_ids as string[]) + : []; + const [isOpen, setIsOpen] = useState(false); + const [baseUrlDraft, setBaseUrlDraft] = useState(connection.base_url ?? ""); + const [apiKeyDraft, setApiKeyDraft] = useState(""); + const [showApiKey, setShowApiKey] = useState(false); + const [allowlistText, setAllowlistText] = useState(allowlist.join(", ")); + + const isLocal = + connection.provider === "ollama_chat" || + connection.provider === "lm_studio" || + !connection.base_url?.startsWith("https"); + const hasConnectionChanges = + baseUrlDraft.trim() !== (connection.base_url ?? "") || + apiKeyDraft.trim() !== (connection.api_key ?? ""); + + function handleOpenChange(open: boolean) { + setIsOpen(open); + if (open) { + setBaseUrlDraft(connection.base_url ?? ""); + setApiKeyDraft(connection.api_key ?? ""); + setShowApiKey(false); + setAllowlistText(allowlist.join(", ")); + } + } + + function saveConnectionSettings() { + const data: ConnectionUpdateRequest = { + base_url: baseUrlDraft.trim() || null, + }; + + if (apiKeyDraft.trim() !== (connection.api_key ?? "")) { + data.api_key = apiKeyDraft.trim() || null; + } + + updateConnection.mutate( + { id: connection.id, data }, + { + onSuccess: () => setApiKeyDraft(""), + } + ); + } + + function saveAllowlist() { + const ids = allowlistText + .split(",") + .map((value) => value.trim()) + .filter(Boolean); + updateConnection.mutate({ + id: connection.id, + data: { extra: { ...(connection.extra ?? {}), model_ids: ids } }, + }); + } + + function handleToggleModel(model: ModelRead, enabled: boolean) { + updateModel.mutate({ + id: model.id, + data: { enabled }, + }); + } + + function handleBulkToggle(models: ModelRead[], enabled: boolean) { + bulkUpdateModels.mutate({ + connectionId: connection.id, + data: { model_ids: models.map((model) => model.id), enabled }, + }); + } + + return ( + + + + + + +
+ {providerIcon(connection.provider, "size-5")} +
+ + Configure {providerLabel} + + + Manage credentials and choose which models are available from this provider. + +
+
+
+ +
+
+
+ + setBaseUrlDraft(event.target.value)} + placeholder="https://api.example.com/v1" + /> +

+ Leave empty to use the provider default endpoint. +

+
+ +
+ +
+ setApiKeyDraft(event.target.value)} + placeholder={connection.has_api_key ? "Saved API key" : "Paste an API key"} + type={showApiKey ? "text" : "password"} + className="pr-11" + /> + +
+
+ + {!isLocal ? ( +
+ +
+ setAllowlistText(event.target.value)} + placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro" + /> + +
+

+ Leave empty to discover all models. Recommended for providers with large catalogs. +

+
+ ) : null} + + + + discoverModels.mutate(connection.id)} + onAddManual={(modelId) => + addManualModel.mutate({ + connectionId: connection.id, + data: { model_id: modelId }, + }) + } + onToggleModel={handleToggleModel} + onBulkToggle={handleBulkToggle} + /> + + {connection.last_status && connection.last_status !== "OK" ? ( +

+ {connection.last_error || "Could not list models."} Chat may still work; add model + IDs manually if discovery is unavailable. +

+ ) : null} +
+
+ + + + + +
+
+ ); +} diff --git a/surfsense_web/components/settings/model-connections/default-connect-form.tsx b/surfsense_web/components/settings/model-connections/default-connect-form.tsx new file mode 100644 index 000000000..3f261c6b2 --- /dev/null +++ b/surfsense_web/components/settings/model-connections/default-connect-form.tsx @@ -0,0 +1,51 @@ +import { useState } from "react"; +import { ApiBaseUrlField, ApiKeyField, ConnectFormFooter } from "./connect-fields"; +import type { ProviderConnectFormProps } from "./provider-metadata"; + +/** + * Connect form for OpenAI-compatible / native key providers (OpenAI, Anthropic, + * OpenRouter, OpenAI-Compatible, LM Studio, Ollama, …). The base URL is + * prefilled from the provider default. + */ +export function DefaultConnectForm({ + provider, + defaultBaseUrl, + baseUrlRequired, + isPending, + onCancel, + onSubmit, +}: ProviderConnectFormProps) { + const [baseUrl, setBaseUrl] = useState(defaultBaseUrl); + const [apiKey, setApiKey] = useState(""); + const isOllama = provider === "ollama_chat"; + const canSubmit = !(baseUrlRequired && !baseUrl.trim()); + + function handleSubmit() { + onSubmit({ base_url: baseUrl || null, api_key: apiKey || null, extra: {} }); + } + + return ( + <> +
+ + +
+ + + ); +} diff --git a/surfsense_web/components/settings/model-connections/model-utils.ts b/surfsense_web/components/settings/model-connections/model-utils.ts new file mode 100644 index 000000000..1db14b3eb --- /dev/null +++ b/surfsense_web/components/settings/model-connections/model-utils.ts @@ -0,0 +1,25 @@ +import type { ModelRead } from "@/contracts/types/model-connections.types"; + +export type ModelCapabilityFilter = "chat" | "vision" | "image_gen"; + +export const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: string }[] = [ + { key: "chat", label: "Chat" }, + { key: "vision", label: "Vision" }, + { key: "image_gen", label: "Image" }, +]; + +export function modelLabel(model: ModelRead) { + return model.display_name || model.model_id; +} + +export function capability(model: ModelRead, key: ModelCapabilityFilter) { + if (key === "chat") return Boolean(model.supports_chat); + if (key === "vision") return Boolean(model.supports_image_input); + return Boolean(model.supports_image_generation); +} + +export function capabilityLabels(model: ModelRead) { + return MODEL_CAPABILITY_FILTERS.filter((filter) => capability(model, filter.key)) + .map((filter) => filter.label.toLowerCase()) + .join(", "); +} diff --git a/surfsense_web/components/settings/model-connections/models-selection-panel.tsx b/surfsense_web/components/settings/model-connections/models-selection-panel.tsx new file mode 100644 index 000000000..20fbe862f --- /dev/null +++ b/surfsense_web/components/settings/model-connections/models-selection-panel.tsx @@ -0,0 +1,196 @@ +import { RefreshCcw } from "lucide-react"; +import { useState } from "react"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Checkbox } from "@/components/ui/checkbox"; +import { Input } from "@/components/ui/input"; +import type { ModelRead } from "@/contracts/types/model-connections.types"; +import { + capability, + capabilityLabels, + MODEL_CAPABILITY_FILTERS, + type ModelCapabilityFilter, + modelLabel, +} from "./model-utils"; + +interface ModelsSelectionPanelProps { + models: ModelRead[]; + description?: string; + emptyMessage?: string; + manualInputPlaceholder?: string; + refreshLabel?: string; + isRefreshing?: boolean; + isAddingManual?: boolean; + isUpdatingModel?: boolean; + isBulkUpdating?: boolean; + onRefresh?: () => void; + onAddManual?: (modelId: string) => void; + onToggleModel?: (model: ModelRead, enabled: boolean) => void; + onBulkToggle?: (models: ModelRead[], enabled: boolean) => void; +} + +export function ModelsSelectionPanel({ + models, + description = "Select models to make available for this provider.", + emptyMessage = "No models yet. Use the refresh button to discover models or add one manually.", + manualInputPlaceholder = "Add a model ID manually", + refreshLabel = "Refresh models", + isRefreshing = false, + isAddingManual = false, + isUpdatingModel = false, + isBulkUpdating = false, + onRefresh, + onAddManual, + onToggleModel, + onBulkToggle, +}: ModelsSelectionPanelProps) { + const [manualModelId, setManualModelId] = useState(""); + const [modelFilter, setModelFilter] = useState(null); + + const filteredModels = modelFilter + ? models.filter((model) => capability(model, modelFilter)) + : models; + const allFilteredModelsEnabled = + filteredModels.length > 0 && filteredModels.every((model) => model.enabled); + + function addModel() { + const modelId = manualModelId.trim(); + if (!modelId || !onAddManual) return; + onAddManual(modelId); + setManualModelId(""); + } + + function toggleFilteredModels() { + const nextEnabled = !allFilteredModelsEnabled; + const changedModels = filteredModels.filter((model) => model.enabled !== nextEnabled); + if (changedModels.length === 0) return; + onBulkToggle?.(changedModels, nextEnabled); + } + + return ( +
+
+
+
Models
+

{description}

+
+
+ + {onRefresh ? ( + + ) : null} +
+
+ + {onAddManual ? ( +
+ setManualModelId(event.target.value)} + onKeyDown={(event) => { + if (event.key === "Enter") { + event.preventDefault(); + addModel(); + } + }} + placeholder={manualInputPlaceholder} + /> + +
+ ) : null} + + {models.length > 0 ? ( +
+ Filter models + {MODEL_CAPABILITY_FILTERS.map((filter) => { + const count = models.filter((model) => capability(model, filter.key)).length; + const isActive = modelFilter === filter.key; + + return ( + + ); + })} +
+ ) : null} + +
+ {models.length === 0 ? ( +
+ {emptyMessage} +
+ ) : null} + {filteredModels.length === 0 && modelFilter ? ( +
+ No{" "} + {MODEL_CAPABILITY_FILTERS.find( + (filter) => filter.key === modelFilter + )?.label.toLowerCase()}{" "} + models found on this connection. +
+ ) : null} +
+ {filteredModels.map((model) => ( +
+ onToggleModel?.(model, checked === true)} + disabled={!onToggleModel || isUpdatingModel} + /> +
+
+ {modelLabel(model)} + {model.source === "MANUAL" ? ( + + manual + + ) : null} +
+
+ {capabilityLabels(model) || "No discovered capabilities"} +
+
+
+ ))} +
+
+
+ ); +} diff --git a/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx b/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx new file mode 100644 index 000000000..871a66cc5 --- /dev/null +++ b/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx @@ -0,0 +1,154 @@ +import { Button } from "@/components/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; +import type { + ConnectionRead, + ModelProviderRead, + ModelRead, +} from "@/contracts/types/model-connections.types"; +import { AzureConnectForm } from "./azure-connect-form"; +import { BedrockConnectForm } from "./bedrock-connect-form"; +import { DefaultConnectForm } from "./default-connect-form"; +import { ModelsSelectionPanel } from "./models-selection-panel"; +import { + type ConnectionDraft, + type ProviderConnectFormProps, + providerDefaultBaseUrl, + providerDisplay, + providerIcon, +} from "./provider-metadata"; +import { VertexConnectForm } from "./vertex-connect-form"; + +interface ProviderConnectDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + provider: string; + selectedProvider?: ModelProviderRead; + isPending: boolean; + onSubmit: (draft: ConnectionDraft) => void; + connectedConnection?: ConnectionRead | null; + connectModels?: ModelRead[]; + isDiscoveringModels?: boolean; + isAddingManualModel?: boolean; + isUpdatingModel?: boolean; + isBulkUpdatingModels?: boolean; + onRefreshModels?: () => void; + onAddManualModel?: (modelId: string) => void; + onToggleModel?: (model: ModelRead, enabled: boolean) => void; + onBulkToggleModels?: (models: ModelRead[], enabled: boolean) => void; + onDone?: () => void; +} + +/** + * Shared dialog shell for the "Add Provider" flow. It owns the header and routes + * to the provider-specific connect form. Forms remount on open (Radix unmounts + * closed content), so each gets fresh, prefilled state. + */ +export function ProviderConnectDialog({ + open, + onOpenChange, + provider, + selectedProvider, + isPending, + onSubmit, + connectedConnection, + connectModels = [], + isDiscoveringModels = false, + isAddingManualModel = false, + isUpdatingModel = false, + isBulkUpdatingModels = false, + onRefreshModels, + onAddManualModel, + onToggleModel, + onBulkToggleModels, + onDone, +}: ProviderConnectDialogProps) { + const meta = providerDisplay(provider); + const isModelSelectionStep = Boolean(connectedConnection); + + const formProps: ProviderConnectFormProps = { + provider, + defaultBaseUrl: providerDefaultBaseUrl(provider, selectedProvider?.default_base_url), + baseUrlRequired: Boolean(selectedProvider?.base_url_required), + isPending, + onCancel: () => onOpenChange(false), + onSubmit, + }; + + return ( + + + +
+ {providerIcon(provider, "size-5")} +
+ + {isModelSelectionStep ? `Select ${meta.name} models` : `Connect ${meta.name}`} + + + {isModelSelectionStep + ? selectedProvider?.discovery === "static" + ? "Choose from known model IDs or add one manually." + : "Choose which discovered models should be available in this search space." + : meta.subtitle} + +
+
+
+ {isModelSelectionStep ? ( + <> +
+ +
+ + + + + ) : ( +
+ {provider === "azure" ? ( + + ) : provider === "bedrock" ? ( + + ) : provider === "vertex_ai" ? ( + + ) : ( + + )} +
+ )} +
+
+ ); +} diff --git a/surfsense_web/components/settings/model-connections/provider-metadata.tsx b/surfsense_web/components/settings/model-connections/provider-metadata.tsx new file mode 100644 index 000000000..0ca8ae419 --- /dev/null +++ b/surfsense_web/components/settings/model-connections/provider-metadata.tsx @@ -0,0 +1,139 @@ +import { getProviderIcon } from "@/lib/provider-icons"; + +export const PROVIDER_ORDER = [ + "openai", + "anthropic", + "vertex_ai", + "bedrock", + "azure", + "openrouter", + "ollama_chat", + "lm_studio", + "openai_compatible", +]; + +export const PROVIDER_DISPLAY: Record< + string, + { name: string; subtitle: string; iconKey?: string; defaultBaseUrl?: string } +> = { + anthropic: { + name: "Claude", + subtitle: "Anthropic", + iconKey: "anthropic", + defaultBaseUrl: "https://api.anthropic.com/v1", + }, + azure: { name: "Azure OpenAI", subtitle: "Microsoft Azure", iconKey: "azure_openai" }, + bedrock: { name: "Amazon Bedrock", subtitle: "AWS", iconKey: "bedrock" }, + lm_studio: { name: "LM Studio", subtitle: "LM Studio", iconKey: "custom" }, + ollama_chat: { name: "Ollama", subtitle: "Ollama", iconKey: "ollama" }, + openai: { + name: "GPT", + subtitle: "OpenAI", + iconKey: "openai", + defaultBaseUrl: "https://api.openai.com/v1", + }, + openai_compatible: { + name: "OpenAI-Compatible", + subtitle: "OpenAI-compatible endpoint", + iconKey: "custom", + }, + openrouter: { + name: "OpenRouter", + subtitle: "OpenRouter", + iconKey: "openrouter", + defaultBaseUrl: "https://openrouter.ai/api/v1", + }, + vertex_ai: { name: "Gemini", subtitle: "Google Cloud Vertex AI", iconKey: "vertex_ai" }, +}; + +export function providerDisplay(provider: string) { + const fallback = provider + .split("_") + .filter(Boolean) + .map((part) => part.charAt(0).toUpperCase() + part.slice(1)) + .join(" "); + + return ( + PROVIDER_DISPLAY[provider] ?? { + name: fallback || provider, + subtitle: provider, + iconKey: provider, + } + ); +} + +export function providerIcon(provider: string, className = "size-4") { + return getProviderIcon(providerDisplay(provider).iconKey ?? provider, { className }); +} + +export function providerDefaultBaseUrl(provider: string, registryDefault?: string | null) { + return registryDefault ?? PROVIDER_DISPLAY[provider]?.defaultBaseUrl ?? ""; +} + +export const AWS_REGION_OPTIONS = [ + "us-east-1", + "us-east-2", + "us-west-2", + "us-gov-east-1", + "us-gov-west-1", + "ap-northeast-1", + "ap-south-1", + "ap-southeast-1", + "ap-southeast-2", + "ap-east-1", + "ca-central-1", + "eu-central-1", + "eu-west-2", +]; + +export const VERTEX_DEFAULT_LOCATION = "global"; + +export const BEDROCK_AUTH_IAM = "iam"; +export const BEDROCK_AUTH_ACCESS_KEY = "access_key"; +export const BEDROCK_AUTH_LONG_TERM_API_KEY = "long_term_api_key"; + +export const VERTEX_AUTH_SERVICE_ACCOUNT = "service_account_json"; +export const VERTEX_AUTH_WORKLOAD_IDENTITY = "workload_identity"; + +// Mirrors Onyx's Azure "Target URI" parser: the user pastes the full endpoint +// (e.g. https://res.cognitiveservices.azure.com/openai/deployments//chat/completions?api-version=) +// which we split into api base (origin), api version, and deployment name. +export function parseAzureTargetUri(rawUri: string) { + try { + const url = new URL(rawUri); + const deploymentMatch = url.pathname.match(/\/openai\/deployments\/([^/]+)/i); + return { + origin: url.origin, + apiVersion: url.searchParams.get("api-version")?.trim() ?? "", + deploymentName: deploymentMatch?.[1] ? deploymentMatch[1].toLowerCase() : "", + isResponsesPath: /\/openai\/responses/i.test(url.pathname), + }; + } catch { + return null; + } +} + +export function isValidAzureTargetUri(rawUri: string) { + const parsed = parseAzureTargetUri(rawUri); + if (!parsed) return false; + return Boolean(parsed.apiVersion) && (Boolean(parsed.deploymentName) || parsed.isResponsesPath); +} + +/** Connection payload produced by a provider connect form. */ +export interface ConnectionDraft { + base_url: string | null; + api_key: string | null; + extra: Record; + /** Model id to seed after creation (providers without discovery, e.g. Azure). */ + seedModelId?: string; +} + +/** Props shared by every provider-specific connect form. */ +export interface ProviderConnectFormProps { + provider: string; + defaultBaseUrl: string; + baseUrlRequired: boolean; + isPending: boolean; + onCancel: () => void; + onSubmit: (draft: ConnectionDraft) => void; +} diff --git a/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx b/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx new file mode 100644 index 000000000..096d3df2e --- /dev/null +++ b/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx @@ -0,0 +1,127 @@ +import { useState } from "react"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { ConnectFormFooter } from "./connect-fields"; +import { + type ProviderConnectFormProps, + VERTEX_AUTH_SERVICE_ACCOUNT, + VERTEX_AUTH_WORKLOAD_IDENTITY, + VERTEX_DEFAULT_LOCATION, +} from "./provider-metadata"; + +/** + * Google Vertex AI (Gemini) connect form. Service-account auth uploads a + * credentials JSON file (read into a string); workload identity collects a + * project id. Credentials ride along in `extra.litellm_params`. + */ +export function VertexConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) { + const [authMethod, setAuthMethod] = useState(VERTEX_AUTH_SERVICE_ACCOUNT); + const [location, setLocation] = useState(VERTEX_DEFAULT_LOCATION); + const [credentials, setCredentials] = useState(""); + const [project, setProject] = useState(""); + + const canSubmit = + authMethod === VERTEX_AUTH_SERVICE_ACCOUNT ? Boolean(credentials) : Boolean(project); + + async function handleCredentialsFile(file: File | undefined) { + if (!file) return; + setCredentials(await file.text()); + } + + function handleSubmit() { + const params: Record = {}; + if (location) params.vertex_location = location; + if (authMethod === VERTEX_AUTH_SERVICE_ACCOUNT) { + if (credentials) params.vertex_credentials = credentials; + } else if (project) { + params.vertex_project = project; + } + onSubmit({ base_url: null, api_key: null, extra: { litellm_params: params } }); + } + + return ( + <> +
+
+ + +
+
+ + setLocation(event.target.value)} + placeholder={VERTEX_DEFAULT_LOCATION} + /> +

+ Region where your Google Vertex AI models are hosted. +

+
+ {authMethod === VERTEX_AUTH_SERVICE_ACCOUNT ? ( +
+ + handleCredentialsFile(event.target.files?.[0])} + /> + +

+ {credentials + ? "Credentials file loaded." + : "Attach your service account key JSON from Google Cloud."} +

+
+ ) : ( +
+ + setProject(event.target.value)} + placeholder="my-vertex-project" + /> +

+ The GCP project where Vertex AI is enabled. +

+
+ )} +

+ Add Vertex AI model IDs from the provider's settings after connecting. +

+
+ + + ); +} diff --git a/surfsense_web/lib/provider-icons.tsx b/surfsense_web/lib/provider-icons.tsx index e63c5eb2f..3bb310904 100644 --- a/surfsense_web/lib/provider-icons.tsx +++ b/surfsense_web/lib/provider-icons.tsx @@ -1,4 +1,4 @@ -import { Bot, Shuffle } from "lucide-react"; +import { Cpu, Shuffle } from "lucide-react"; import { Ai21Icon, AnthropicIcon, @@ -72,7 +72,7 @@ export function getProviderIcon( case "COMETAPI": return ; case "CUSTOM": - return ; + return ; case "DATABRICKS": return ; case "DEEPINFRA": @@ -122,6 +122,6 @@ export function getProviderIcon( case "ZHIPU": return ; default: - return ; + return ; } } From c6e71c851cdba5323f6b84b823650fd0e079e07a Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 18:52:45 +0200 Subject: [PATCH 086/212] feat(chunks): add explicit position column with backfill migration Chunk ids stop reflecting document order once incremental re-indexing keeps unchanged rows across edits. Backfill preserves the historical id ordering so behavior is identical on day one. --- .../versions/162_add_chunk_position.py | 51 +++++++++++++++++++ surfsense_backend/app/db.py | 8 ++- 2 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 surfsense_backend/alembic/versions/162_add_chunk_position.py diff --git a/surfsense_backend/alembic/versions/162_add_chunk_position.py b/surfsense_backend/alembic/versions/162_add_chunk_position.py new file mode 100644 index 000000000..cb240e3ef --- /dev/null +++ b/surfsense_backend/alembic/versions/162_add_chunk_position.py @@ -0,0 +1,51 @@ +"""add chunks.position for explicit document order + +Incremental re-indexing keeps unchanged chunk rows, so auto-increment ids no +longer reflect document order. Backfill preserves the historical id ordering. + +Revision ID: 162 +Revises: 161 +""" + +from collections.abc import Sequence + +from alembic import op + +revision: str = "162" +down_revision: str | None = "161" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.execute( + "ALTER TABLE chunks ADD COLUMN IF NOT EXISTS position INTEGER NOT NULL DEFAULT 0;" + ) + + # Backfill: document order so far has been the insertion order (id). + op.execute( + """ + UPDATE chunks + SET position = numbered.rn + FROM ( + SELECT id, + ROW_NUMBER() OVER (PARTITION BY document_id ORDER BY id) - 1 AS rn + FROM chunks + ) AS numbered + WHERE chunks.id = numbered.id; + """ + ) + + op.execute( + "CREATE INDEX IF NOT EXISTS ix_chunks_position ON chunks(position);" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_chunks_document_id_position " + "ON chunks(document_id, position);" + ) + + +def downgrade() -> None: + op.execute("DROP INDEX IF EXISTS ix_chunks_document_id_position;") + op.execute("DROP INDEX IF EXISTS ix_chunks_position;") + op.execute("ALTER TABLE chunks DROP COLUMN IF EXISTS position;") diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 9ec13f4e2..8d110bbf1 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -1484,7 +1484,10 @@ class Document(BaseModel, TimestampMixin): created_by = relationship("User", back_populates="documents") connector = relationship("SearchSourceConnector", back_populates="documents") chunks = relationship( - "Chunk", back_populates="document", cascade="all, delete-orphan" + "Chunk", + back_populates="document", + cascade="all, delete-orphan", + order_by="Chunk.position", ) # Original upload + future derived artifacts (redacted, filled-form). # Model lives in app.file_storage.persistence to keep that feature cohesive. @@ -1520,6 +1523,9 @@ class Chunk(BaseModel, TimestampMixin): content = Column(Text, nullable=False) embedding = Column(Vector(config.embedding_model_instance.dimension)) + # Explicit document order; ids don't follow it since incremental + # re-indexing keeps unchanged rows across edits. + position = Column(Integer, nullable=False, server_default="0", index=True) document_id = Column( Integer, From f82dedf712862b96e5748cbc3ae539ae5a071e0c Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 18:52:46 +0200 Subject: [PATCH 087/212] feat(indexing): add pure chunk reconciler for content-addressed diffs Greedy multiset match on chunk text decides which rows keep their embeddings, which texts need embedding, and which rows are deleted. No DB, no embeddings; fully unit-tested (reuse, head insert, middle edit, deletion, duplicates, reorder, full rewrite). --- .../app/indexing_pipeline/chunk_reconciler.py | 56 +++++++++++ .../test_chunk_reconciler.py | 94 +++++++++++++++++++ 2 files changed, 150 insertions(+) create mode 100644 surfsense_backend/app/indexing_pipeline/chunk_reconciler.py create mode 100644 surfsense_backend/tests/unit/indexing_pipeline/test_chunk_reconciler.py diff --git a/surfsense_backend/app/indexing_pipeline/chunk_reconciler.py b/surfsense_backend/app/indexing_pipeline/chunk_reconciler.py new file mode 100644 index 000000000..9354aeb9f --- /dev/null +++ b/surfsense_backend/app/indexing_pipeline/chunk_reconciler.py @@ -0,0 +1,56 @@ +"""Diff a document's existing chunk rows against its freshly chunked texts. + +Embeddings are a pure function of chunk text, so a row whose content reappears +in the new chunking keeps its embedding (and its HNSW/GIN index entries); only +genuinely new texts are embedded and only vanished rows are deleted. Matching +is a greedy multiset match on content in document order, so duplicate +boilerplate chunks pair up one-to-one and reordered chunks become cheap +position updates instead of delete+reinsert. +""" + +from __future__ import annotations + +from collections import defaultdict, deque +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class ExistingChunk: + id: int + content: str + position: int + + +@dataclass(frozen=True, slots=True) +class ChunkPlan: + """The minimal set of writes that turns the stored chunks into the new ones. + + ``reused`` holds only kept rows whose position actually changed; rows that + match in place need no write at all. Kept-row count (for metrics) is + ``len(existing) - len(to_delete)``. + """ + + reused: list[tuple[int, int]] # (existing_chunk_id, new_position) + to_embed: list[tuple[int, str]] # (new_position, text) + to_delete: list[int] # existing chunk ids + + +def reconcile(existing: list[ExistingChunk], new_texts: list[str]) -> ChunkPlan: + available: dict[str, deque[ExistingChunk]] = defaultdict(deque) + for chunk in sorted(existing, key=lambda c: c.position): + available[chunk.content].append(chunk) + + reused: list[tuple[int, int]] = [] + to_embed: list[tuple[int, str]] = [] + + for new_position, text in enumerate(new_texts): + matches = available.get(text) + if matches: + chunk = matches.popleft() + if chunk.position != new_position: + reused.append((chunk.id, new_position)) + else: + to_embed.append((new_position, text)) + + to_delete = [chunk.id for queue in available.values() for chunk in queue] + return ChunkPlan(reused=reused, to_embed=to_embed, to_delete=to_delete) diff --git a/surfsense_backend/tests/unit/indexing_pipeline/test_chunk_reconciler.py b/surfsense_backend/tests/unit/indexing_pipeline/test_chunk_reconciler.py new file mode 100644 index 000000000..7effce840 --- /dev/null +++ b/surfsense_backend/tests/unit/indexing_pipeline/test_chunk_reconciler.py @@ -0,0 +1,94 @@ +"""reconcile(): diff existing chunk rows against new chunk texts. + +The reconciler decides which rows (and embeddings) survive an edit, which texts +must be embedded, and which rows go away -- purely from content, no DB. +""" + +from __future__ import annotations + +from app.indexing_pipeline.chunk_reconciler import ExistingChunk, reconcile + + +def _existing(*contents: str) -> list[ExistingChunk]: + return [ + ExistingChunk(id=i + 1, content=text, position=i) + for i, text in enumerate(contents) + ] + + +def test_identical_content_keeps_every_row_untouched(): + plan = reconcile(_existing("alpha", "beta", "gamma"), ["alpha", "beta", "gamma"]) + + assert plan.to_embed == [] + assert plan.to_delete == [] + assert plan.reused == [] + + +def test_head_insert_embeds_only_the_new_chunk_and_shifts_the_rest(): + plan = reconcile(_existing("alpha", "beta"), ["intro", "alpha", "beta"]) + + assert plan.to_embed == [(0, "intro")] + assert plan.to_delete == [] + # alpha: position 0 -> 1, beta: 1 -> 2; embeddings untouched. + assert plan.reused == [(1, 1), (2, 2)] + + +def test_middle_edit_swaps_exactly_one_chunk(): + plan = reconcile( + _existing("alpha", "beta", "gamma"), ["alpha", "beta EDITED", "gamma"] + ) + + assert plan.to_embed == [(1, "beta EDITED")] + assert plan.to_delete == [2] + # Neighbours did not move, so no position writes at all. + assert plan.reused == [] + + +def test_removed_chunk_is_deleted_and_followers_shift_up(): + plan = reconcile(_existing("alpha", "beta", "gamma"), ["alpha", "gamma"]) + + assert plan.to_embed == [] + assert plan.to_delete == [2] + assert plan.reused == [(3, 1)] + + +def test_duplicate_texts_pair_up_one_to_one(): + # Two identical boilerplate chunks, only one survives the edit: exactly one + # row is kept and exactly one is deleted -- never both kept or both dropped. + plan = reconcile(_existing("boiler", "boiler", "body"), ["boiler", "body"]) + + assert plan.to_embed == [] + assert plan.to_delete == [2] + assert plan.reused == [(3, 1)] + + +def test_duplicate_growth_embeds_only_the_extra_copy(): + plan = reconcile(_existing("boiler", "body"), ["boiler", "boiler", "body"]) + + assert plan.to_embed == [(1, "boiler")] + assert plan.to_delete == [] + assert plan.reused == [(2, 2)] + + +def test_reorder_becomes_position_updates_with_no_embedding(): + plan = reconcile(_existing("alpha", "beta"), ["beta", "alpha"]) + + assert plan.to_embed == [] + assert plan.to_delete == [] + assert sorted(plan.reused) == [(1, 1), (2, 0)] + + +def test_full_rewrite_replaces_everything(): + plan = reconcile(_existing("alpha", "beta"), ["new one", "new two"]) + + assert plan.to_embed == [(0, "new one"), (1, "new two")] + assert sorted(plan.to_delete) == [1, 2] + assert plan.reused == [] + + +def test_no_existing_chunks_embeds_all(): + plan = reconcile([], ["alpha", "beta"]) + + assert plan.to_embed == [(0, "alpha"), (1, "beta")] + assert plan.to_delete == [] + assert plan.reused == [] From 8d413ea5c2b644e86de2f626b8ba461cfe4baa42 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 18:52:57 +0200 Subject: [PATCH 088/212] refactor(indexing): expose chunk_markdown and embed_batch helpers Split _compute so the incremental edit path can reuse the exact same chunker selection and embedding entry points (and their test patch targets) without going through the doc-level cache. --- .../cache/cached_indexing.py | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py b/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py index c93f2f133..95321a229 100644 --- a/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py +++ b/surfsense_backend/app/indexing_pipeline/cache/cached_indexing.py @@ -58,7 +58,9 @@ async def build_chunk_embeddings( cached = await _recall(key) if cached is not None: metrics.record_embedding_cache_lookup( - embedding_model=key.embedding_model, chunker_kind=chunker_kind, outcome="hit" + embedding_model=key.embedding_model, + chunker_kind=chunker_kind, + outcome="hit", ) logger.debug("Embedding cache hit for %s", key.markdown_sha256) return cached.summary_embedding, [(c.text, c.embedding) for c in cached.chunks] @@ -73,18 +75,24 @@ async def build_chunk_embeddings( return summary_embedding, chunk_pairs +async def chunk_markdown(markdown: str, *, use_code_chunker: bool) -> list[str]: + """Chunk markdown into ordered texts with the pipeline's chunker selection.""" + if use_code_chunker: + return await asyncio.to_thread(chunk_text, markdown, use_code_chunker=True) + # Table-aware hybrid chunker keeps Markdown tables intact (issue #1334). + return await asyncio.to_thread(chunk_text_hybrid, markdown) + + +async def embed_batch(texts: list[str]) -> list[np.ndarray]: + """Embed texts in one batch off the event loop.""" + return await asyncio.to_thread(embed_texts, texts) + + async def _compute( markdown: str, *, use_code_chunker: bool ) -> tuple[np.ndarray, list[ChunkPair]]: - if use_code_chunker: - chunk_texts = await asyncio.to_thread( - chunk_text, markdown, use_code_chunker=True - ) - else: - # Table-aware hybrid chunker keeps Markdown tables intact (issue #1334). - chunk_texts = await asyncio.to_thread(chunk_text_hybrid, markdown) - - embeddings = await asyncio.to_thread(embed_texts, [markdown, *chunk_texts]) + chunk_texts = await chunk_markdown(markdown, use_code_chunker=use_code_chunker) + embeddings = await embed_batch([markdown, *chunk_texts]) summary_embedding, *chunk_embeddings = embeddings return summary_embedding, list(zip(chunk_texts, chunk_embeddings, strict=False)) From fd495e1b2f69034e038798a291300b0d2fbce7b2 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 18:52:57 +0200 Subject: [PATCH 089/212] feat(observability): add chunk reconcile metric and kill-switch flag surfsense.indexing.reconcile.chunks counts reused/embedded/deleted chunks per re-index. CHUNK_RECONCILE_ENABLED (default on) falls back to delete-all + full re-embed if the diff path ever misbehaves. --- surfsense_backend/.env.example | 5 ++++ surfsense_backend/app/config/__init__.py | 7 ++++++ .../app/observability/metrics.py | 23 +++++++++++++++++++ 3 files changed, 35 insertions(+) diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index ac289c5a6..1e09b266a 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -342,6 +342,11 @@ EMBEDDING_CACHE_ENABLED=false # Rows deleted per eviction pass. # EMBEDDING_CACHE_EVICTION_BATCH=500 +# Incremental re-indexing: on document edits, keep chunks whose text is +# unchanged (reusing their embeddings) and embed only new/changed ones. +# Set to false to fall back to delete-all + full re-embed (kill switch). +# CHUNK_RECONCILE_ENABLED=true + # Daytona Sandbox (isolated code execution) # DAYTONA_SANDBOX_ENABLED=FALSE # DAYTONA_API_KEY=your-daytona-api-key diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 549252cec..c242419f6 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -979,6 +979,13 @@ class Config: os.getenv("EMBEDDING_CACHE_EVICTION_BATCH", "500") ) + # Incremental re-indexing: on document edits, keep chunk rows whose text is + # unchanged (reusing their embeddings) and embed only new/changed chunks. + # Kill switch -- disabling falls back to delete-all + full re-embed. + CHUNK_RECONCILE_ENABLED = ( + os.getenv("CHUNK_RECONCILE_ENABLED", "true").strip().lower() == "true" + ) + # Proxy provider selection. Maps to a ProxyProvider implementation registered # in app/utils/proxy/registry.py. Add new vendors there and switch via this var. PROXY_PROVIDER = os.getenv("PROXY_PROVIDER", "anonymous_proxies") diff --git a/surfsense_backend/app/observability/metrics.py b/surfsense_backend/app/observability/metrics.py index 94bb55740..ade43ab01 100644 --- a/surfsense_backend/app/observability/metrics.py +++ b/surfsense_backend/app/observability/metrics.py @@ -321,6 +321,17 @@ def _embedding_cache_evictions(): ) +@lru_cache(maxsize=1) +def _chunk_reconcile_chunks(): + return _get_meter().create_counter( + "surfsense.indexing.reconcile.chunks", + description=( + "Chunks handled by incremental re-indexing, by outcome " + "(reused/embedded/deleted)." + ), + ) + + @lru_cache(maxsize=1) def _celery_heartbeat_refreshes(): return _get_meter().create_counter( @@ -746,6 +757,17 @@ def record_embedding_cache_eviction(count: int, *, phase: str) -> None: _add(_embedding_cache_evictions(), count, {"phase": phase}) +def record_chunk_reconcile(*, reused: int, embedded: int, deleted: int) -> None: + """Record an incremental re-index: how many chunks were kept vs recomputed.""" + for outcome, count in ( + ("reused", reused), + ("embedded", embedded), + ("deleted", deleted), + ): + if count > 0: + _add(_chunk_reconcile_chunks(), count, {"outcome": outcome}) + + def record_celery_heartbeat_refresh(*, heartbeat_type: str) -> None: _add(_celery_heartbeat_refreshes(), 1, {"heartbeat.type": heartbeat_type}) @@ -939,6 +961,7 @@ __all__ = [ "record_celery_queue_latency", "record_chat_request_duration", "record_chat_request_outcome", + "record_chunk_reconcile", "record_compaction_run", "record_connector_sync_duration", "record_connector_sync_outcome", From 7d55aaf2c183b37b715be3c9f4d15cff5156e04a Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 18:53:08 +0200 Subject: [PATCH 090/212] feat(indexing): reconcile chunks incrementally on re-index index() now loads existing rows and applies a content diff instead of delete-all/reinsert-all: unchanged chunks keep their rows and embeddings (zero HNSW/GIN churn), moved chunks get a position-only UPDATE, and only new texts are embedded, batched with the summary embedding. First index keeps the cache-aware build_chunk_embeddings path. --- .../indexing_pipeline_service.py | 119 +++++++++++++++--- 1 file changed, 101 insertions(+), 18 deletions(-) diff --git a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py index 271b3ee03..224eb0f5d 100644 --- a/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py +++ b/surfsense_backend/app/indexing_pipeline/indexing_pipeline_service.py @@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from datetime import UTC, datetime -from sqlalchemy import delete, select +from sqlalchemy import delete, select, update from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession @@ -20,6 +20,8 @@ from app.db import ( DocumentType, ) from app.indexing_pipeline.cache import build_chunk_embeddings +from app.indexing_pipeline.cache.cached_indexing import chunk_markdown, embed_batch +from app.indexing_pipeline.chunk_reconciler import ExistingChunk, reconcile from app.indexing_pipeline.connector_document import ConnectorDocument from app.indexing_pipeline.document_hashing import ( compute_content_hash, @@ -379,39 +381,34 @@ class IndexingPipelineService: content = connector_doc.source_markdown - await self.session.execute( - delete(Chunk).where(Chunk.document_id == document.id) - ) - t_step = time.perf_counter() - summary_embedding, chunk_pairs = await build_chunk_embeddings( - content, - use_code_chunker=connector_doc.should_use_code_chunker, - ) - - chunks = [ - Chunk(content=text, embedding=emb) for text, emb in chunk_pairs - ] + existing = await self._load_existing_chunks(document.id) + if existing and self._reconcile_enabled(): + chunk_count = await self._reindex_incrementally( + document, content, connector_doc, existing + ) + else: + chunk_count = await self._reindex_from_scratch( + document, content, connector_doc + ) perf.info( "[indexing] chunk+embed doc=%d chunks=%d in %.3fs", document.id, - len(chunks), + chunk_count, time.perf_counter() - t_step, ) document.content = content - document.embedding = summary_embedding - attach_chunks_to_document(document, chunks) document.updated_at = datetime.now(UTC) document.status = DocumentStatus.ready() await self.session.commit() perf.info( "[indexing] index TOTAL doc=%d chunks=%d in %.3fs", document.id, - len(chunks), + chunk_count, time.perf_counter() - t_index, ) - log_index_success(ctx, chunk_count=len(chunks)) + log_index_success(ctx, chunk_count=chunk_count) outcome_status = "success" await self._enqueue_ai_sort_if_enabled(document) @@ -468,6 +465,92 @@ class IndexingPipelineService: persist_span_cm.__exit__(*sys.exc_info()) return document + @staticmethod + def _reconcile_enabled() -> bool: + from app.config import config + + return config.CHUNK_RECONCILE_ENABLED + + async def _load_existing_chunks(self, document_id: int) -> list[ExistingChunk]: + result = await self.session.execute( + select(Chunk.id, Chunk.content, Chunk.position).where( + Chunk.document_id == document_id + ) + ) + return [ + ExistingChunk(id=row.id, content=row.content, position=row.position) + for row in result + ] + + async def _reindex_from_scratch( + self, document: Document, content: str, connector_doc: ConnectorDocument + ) -> int: + """First index (or kill-switched re-index): cache-aware full chunk+embed.""" + await self.session.execute( + delete(Chunk).where(Chunk.document_id == document.id) + ) + + summary_embedding, chunk_pairs = await build_chunk_embeddings( + content, + use_code_chunker=connector_doc.should_use_code_chunker, + ) + + chunks = [ + Chunk(content=text, embedding=emb, position=i) + for i, (text, emb) in enumerate(chunk_pairs) + ] + document.embedding = summary_embedding + attach_chunks_to_document(document, chunks) + return len(chunks) + + async def _reindex_incrementally( + self, + document: Document, + content: str, + connector_doc: ConnectorDocument, + existing: list[ExistingChunk], + ) -> int: + """Edit path: keep rows whose text survived, embed only new texts. + + Unchanged rows keep their embedding and their HNSW/GIN index entries; + moved rows get a position-only UPDATE, which touches neither index. + """ + new_texts = await chunk_markdown( + content, use_code_chunker=connector_doc.should_use_code_chunker + ) + plan = reconcile(existing, new_texts) + + # One batch: the document-level summary vector plus the missing chunks. + embeddings = await embed_batch([content, *[t for _, t in plan.to_embed]]) + summary_embedding, *new_embeddings = embeddings + + if plan.reused: + await self.session.execute( + update(Chunk), + [{"id": cid, "position": pos} for cid, pos in plan.reused], + ) + if plan.to_delete: + await self.session.execute( + delete(Chunk).where(Chunk.id.in_(plan.to_delete)) + ) + self.session.add_all( + Chunk( + content=text, + embedding=emb, + position=pos, + document_id=document.id, + ) + for (pos, text), emb in zip(plan.to_embed, new_embeddings, strict=True) + ) + document.embedding = summary_embedding + + ot_metrics.record_chunk_reconcile( + reused=len(existing) - len(plan.to_delete), + embedded=len(plan.to_embed), + deleted=len(plan.to_delete), + ) + return len(new_texts) + async def _enqueue_ai_sort_if_enabled(self, document: Document) -> None: """Fire-and-forget: enqueue incremental AI sort if the search space has it enabled.""" try: From 5a71769dba8b678baba29bfef0a7fc0c35d7cdd4 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 18:53:08 +0200 Subject: [PATCH 091/212] fix(chunks): set position on remaining chunk insert paths document_converters, the github size-fallback chunker, revert_service restores, and the kb-persistence middleware now write explicit positions (the middleware read path also orders by position). --- .../middleware/kb_persistence/middleware.py | 26 +++++++++++++++---- .../app/services/revert_service.py | 22 ++++++++++++---- .../connector_indexers/github_indexer.py | 1 + .../app/utils/document_converters.py | 6 +++-- 4 files changed, 43 insertions(+), 12 deletions(-) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/middleware.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/middleware.py index ef86eaddd..a6c83a7d4 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/middleware.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/middleware/kb_persistence/middleware.py @@ -241,8 +241,15 @@ async def _create_document( chunk_embeddings = await asyncio.to_thread(embed_texts, chunks) session.add_all( [ - Chunk(document_id=doc.id, content=text, embedding=embedding) - for text, embedding in zip(chunks, chunk_embeddings, strict=True) + Chunk( + document_id=doc.id, + content=text, + embedding=embedding, + position=i, + ) + for i, (text, embedding) in enumerate( + zip(chunks, chunk_embeddings, strict=True) + ) ] ) return doc @@ -289,8 +296,15 @@ async def _update_document( chunk_embeddings = await asyncio.to_thread(embed_texts, chunks) session.add_all( [ - Chunk(document_id=document.id, content=text, embedding=embedding) - for text, embedding in zip(chunks, chunk_embeddings, strict=True) + Chunk( + document_id=document.id, + content=text, + embedding=embedding, + position=i, + ) + for i, (text, embedding) in enumerate( + zip(chunks, chunk_embeddings, strict=True) + ) ] ) return document @@ -475,7 +489,9 @@ async def _load_chunks_for_snapshot( session: AsyncSession, *, doc_id: int ) -> list[dict[str, str]]: rows = await session.execute( - select(Chunk.content).where(Chunk.document_id == doc_id).order_by(Chunk.id) + select(Chunk.content) + .where(Chunk.document_id == doc_id) + .order_by(Chunk.position, Chunk.id) ) return [{"content": row.content} for row in rows.all() if row.content is not None] diff --git a/surfsense_backend/app/services/revert_service.py b/surfsense_backend/app/services/revert_service.py index 6db5e2604..0cb6cd092 100644 --- a/surfsense_backend/app/services/revert_service.py +++ b/surfsense_backend/app/services/revert_service.py @@ -238,9 +238,14 @@ async def _restore_in_place_document( chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts) session.add_all( [ - Chunk(document_id=doc.id, content=text, embedding=embedding) - for text, embedding in zip( - chunk_texts, chunk_embeddings, strict=True + Chunk( + document_id=doc.id, + content=text, + embedding=embedding, + position=i, + ) + for i, (text, embedding) in enumerate( + zip(chunk_texts, chunk_embeddings, strict=True) ) ] ) @@ -336,8 +341,15 @@ async def _reinsert_document_from_revision( chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts) session.add_all( [ - Chunk(document_id=new_doc.id, content=text, embedding=embedding) - for text, embedding in zip(chunk_texts, chunk_embeddings, strict=True) + Chunk( + document_id=new_doc.id, + content=text, + embedding=embedding, + position=i, + ) + for i, (text, embedding) in enumerate( + zip(chunk_texts, chunk_embeddings, strict=True) + ) ] ) diff --git a/surfsense_backend/app/tasks/connector_indexers/github_indexer.py b/surfsense_backend/app/tasks/connector_indexers/github_indexer.py index ce9b80e5e..557c2ce71 100644 --- a/surfsense_backend/app/tasks/connector_indexers/github_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/github_indexer.py @@ -525,6 +525,7 @@ async def _simple_chunk_content(content: str, chunk_size: int = 4000) -> list: Chunk( content=chunk_text, embedding=embed_text(chunk_text), + position=len(chunks), ) ) diff --git a/surfsense_backend/app/utils/document_converters.py b/surfsense_backend/app/utils/document_converters.py index 694ae22ac..fef51d692 100644 --- a/surfsense_backend/app/utils/document_converters.py +++ b/surfsense_backend/app/utils/document_converters.py @@ -188,8 +188,10 @@ async def create_document_chunks(content: str) -> list[Chunk]: chunk_texts = [c.text for c in config.chunker_instance.chunk(content)] chunk_embeddings = await asyncio.to_thread(embed_texts, chunk_texts) return [ - Chunk(content=text, embedding=emb) - for text, emb in zip(chunk_texts, chunk_embeddings, strict=False) + Chunk(content=text, embedding=emb, position=i) + for i, (text, emb) in enumerate( + zip(chunk_texts, chunk_embeddings, strict=False) + ) ] From 052e9ef4d19b16e58bef8f7c22328f64b6f94fbc Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 18:53:21 +0200 Subject: [PATCH 092/212] refactor(chunks): order chunk reads by (document_id, position) Presentation and citation ordering moves off Chunk.id/created_at to the explicit position column (id kept as tiebreaker). Vector and ts_rank ranking order_by clauses are untouched. --- .../shared/middleware/filesystem/backends/kb_postgres.py | 4 ++-- .../shared/middleware/knowledge_search.py | 9 ++++++--- .../builtins/deliverables/tools/knowledge_base.py | 2 +- surfsense_backend/app/retriever/chunks_hybrid_search.py | 7 +++++-- .../app/retriever/documents_hybrid_search.py | 7 +++++-- surfsense_backend/app/routes/documents_routes.py | 8 ++++---- surfsense_backend/app/routes/editor_routes.py | 6 +++--- surfsense_backend/app/services/ai_file_sort_service.py | 2 +- surfsense_backend/app/services/export_service.py | 2 +- 9 files changed, 28 insertions(+), 19 deletions(-) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/kb_postgres.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/kb_postgres.py index 7b8aaf2b0..e13196537 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/kb_postgres.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/filesystem/backends/kb_postgres.py @@ -508,7 +508,7 @@ class KBPostgresBackend(BackendProtocol): chunk_rows = await session.execute( select(Chunk.id, Chunk.content) .where(Chunk.document_id == document.id) - .order_by(Chunk.id) + .order_by(Chunk.position, Chunk.id) ) chunks = [ {"chunk_id": row.id, "content": row.content} for row in chunk_rows.all() @@ -725,7 +725,7 @@ class KBPostgresBackend(BackendProtocol): .join(Document, Document.id == Chunk.document_id) .where(Document.search_space_id == self.search_space_id) .where(Chunk.content.ilike(f"%{pattern}%")) - .order_by(Chunk.document_id, Chunk.id) + .order_by(Chunk.document_id, Chunk.position, Chunk.id) ) chunk_rows = await session.execute(sub) per_doc: dict[int, int] = {} diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/knowledge_search.py b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/knowledge_search.py index 681e80b0e..9ef601791 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/knowledge_search.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/shared/middleware/knowledge_search.py @@ -394,7 +394,10 @@ async def browse_recent_documents( Chunk.document_id, Chunk.content, func.row_number() - .over(partition_by=Chunk.document_id, order_by=Chunk.id) + .over( + partition_by=Chunk.document_id, + order_by=(Chunk.position, Chunk.id), + ) .label("rn"), ) .where(Chunk.document_id.in_(doc_ids)) @@ -404,7 +407,7 @@ async def browse_recent_documents( chunk_query = ( select(numbered.c.chunk_id, numbered.c.document_id, numbered.c.content) .where(numbered.c.rn <= _RECENCY_MAX_CHUNKS_PER_DOC) - .order_by(numbered.c.document_id, numbered.c.chunk_id) + .order_by(numbered.c.document_id, numbered.c.rn) ) chunk_result = await session.execute(chunk_query) fetched_chunks = chunk_result.all() @@ -531,7 +534,7 @@ async def fetch_mentioned_documents( chunk_result = await session.execute( select(Chunk.id, Chunk.content, Chunk.document_id) .where(Chunk.document_id.in_(list(docs.keys()))) - .order_by(Chunk.document_id, Chunk.id) + .order_by(Chunk.document_id, Chunk.position, Chunk.id) ) chunks_by_doc: dict[int, list[dict[str, Any]]] = {doc_id: [] for doc_id in docs} for row in chunk_result.all(): diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/knowledge_base.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/knowledge_base.py index e99e0291a..d89124990 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/knowledge_base.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/knowledge_base.py @@ -122,7 +122,7 @@ async def _browse_recent_documents( chunk_query = ( select(Chunk) .where(Chunk.document_id.in_(doc_ids)) - .order_by(Chunk.document_id, Chunk.id) + .order_by(Chunk.document_id, Chunk.position, Chunk.id) ) chunk_result = await session.execute(chunk_query) raw_chunks = chunk_result.scalars().all() diff --git a/surfsense_backend/app/retriever/chunks_hybrid_search.py b/surfsense_backend/app/retriever/chunks_hybrid_search.py index 47f7fe6b1..5e5edec2e 100644 --- a/surfsense_backend/app/retriever/chunks_hybrid_search.py +++ b/surfsense_backend/app/retriever/chunks_hybrid_search.py @@ -420,7 +420,10 @@ class ChucksHybridSearchRetriever: select( Chunk.id.label("chunk_id"), func.row_number() - .over(partition_by=Chunk.document_id, order_by=Chunk.id) + .over( + partition_by=Chunk.document_id, + order_by=(Chunk.position, Chunk.id), + ) .label("rn"), ) .where(Chunk.document_id.in_(doc_ids)) @@ -441,7 +444,7 @@ class ChucksHybridSearchRetriever: select(Chunk.id, Chunk.content, Chunk.document_id) .join(numbered, Chunk.id == numbered.c.chunk_id) .where(chunk_filter) - .order_by(Chunk.document_id, Chunk.id) + .order_by(Chunk.document_id, Chunk.position, Chunk.id) ) t_fetch = time.perf_counter() diff --git a/surfsense_backend/app/retriever/documents_hybrid_search.py b/surfsense_backend/app/retriever/documents_hybrid_search.py index 9ce86d404..d856e93cf 100644 --- a/surfsense_backend/app/retriever/documents_hybrid_search.py +++ b/surfsense_backend/app/retriever/documents_hybrid_search.py @@ -357,7 +357,10 @@ class DocumentHybridSearchRetriever: select( Chunk.id.label("chunk_id"), func.row_number() - .over(partition_by=Chunk.document_id, order_by=Chunk.id) + .over( + partition_by=Chunk.document_id, + order_by=(Chunk.position, Chunk.id), + ) .label("rn"), ) .where(Chunk.document_id.in_(doc_ids)) @@ -369,7 +372,7 @@ class DocumentHybridSearchRetriever: select(Chunk.id, Chunk.content, Chunk.document_id) .join(numbered, Chunk.id == numbered.c.chunk_id) .where(numbered.c.rn <= _MAX_FETCH_CHUNKS_PER_DOC) - .order_by(Chunk.document_id, Chunk.id) + .order_by(Chunk.document_id, Chunk.position, Chunk.id) ) t_fetch = time.perf_counter() diff --git a/surfsense_backend/app/routes/documents_routes.py b/surfsense_backend/app/routes/documents_routes.py index 865068fba..53f03a0ca 100644 --- a/surfsense_backend/app/routes/documents_routes.py +++ b/surfsense_backend/app/routes/documents_routes.py @@ -1014,8 +1014,8 @@ async def get_document_by_chunk_id( .filter( Chunk.document_id == document.id, or_( - Chunk.created_at < chunk.created_at, - and_(Chunk.created_at == chunk.created_at, Chunk.id < chunk.id), + Chunk.position < chunk.position, + and_(Chunk.position == chunk.position, Chunk.id < chunk.id), ), ) ) @@ -1027,7 +1027,7 @@ async def get_document_by_chunk_id( windowed_result = await session.execute( select(Chunk) .filter(Chunk.document_id == document.id) - .order_by(Chunk.created_at, Chunk.id) + .order_by(Chunk.position, Chunk.id) .offset(start) .limit(end - start) ) @@ -1137,7 +1137,7 @@ async def get_document_chunks_paginated( chunks_result = await session.execute( select(Chunk) .filter(Chunk.document_id == document_id) - .order_by(Chunk.created_at, Chunk.id) + .order_by(Chunk.position, Chunk.id) .offset(offset) .limit(page_size) ) diff --git a/surfsense_backend/app/routes/editor_routes.py b/surfsense_backend/app/routes/editor_routes.py index 166164c50..34828964a 100644 --- a/surfsense_backend/app/routes/editor_routes.py +++ b/surfsense_backend/app/routes/editor_routes.py @@ -119,7 +119,7 @@ async def get_editor_content( chunk_contents_result = await session.execute( select(Chunk.content) .filter(Chunk.document_id == document_id) - .order_by(Chunk.id) + .order_by(Chunk.position, Chunk.id) ) chunk_contents = chunk_contents_result.scalars().all() @@ -205,7 +205,7 @@ async def download_document_markdown( chunk_contents_result = await session.execute( select(Chunk.content) .filter(Chunk.document_id == document_id) - .order_by(Chunk.id) + .order_by(Chunk.position, Chunk.id) ) chunk_contents = chunk_contents_result.scalars().all() if chunk_contents: @@ -354,7 +354,7 @@ async def export_document( chunk_contents_result = await session.execute( select(Chunk.content) .filter(Chunk.document_id == document_id) - .order_by(Chunk.id) + .order_by(Chunk.position, Chunk.id) ) chunk_contents = chunk_contents_result.scalars().all() if chunk_contents: diff --git a/surfsense_backend/app/services/ai_file_sort_service.py b/surfsense_backend/app/services/ai_file_sort_service.py index 2f04131a6..1bf4d325e 100644 --- a/surfsense_backend/app/services/ai_file_sort_service.py +++ b/surfsense_backend/app/services/ai_file_sort_service.py @@ -156,7 +156,7 @@ async def _resolve_document_text( stmt = ( select(Chunk.content) .where(Chunk.document_id == document.id) - .order_by(Chunk.id) + .order_by(Chunk.position, Chunk.id) .limit(_MAX_CHUNKS_FOR_CONTEXT) ) result = await session.execute(stmt) diff --git a/surfsense_backend/app/services/export_service.py b/surfsense_backend/app/services/export_service.py index 97f952223..9e6869fe1 100644 --- a/surfsense_backend/app/services/export_service.py +++ b/surfsense_backend/app/services/export_service.py @@ -62,7 +62,7 @@ async def _get_document_markdown( chunk_result = await session.execute( select(Chunk.content) .filter(Chunk.document_id == document.id) - .order_by(Chunk.id) + .order_by(Chunk.position, Chunk.id) ) chunks = chunk_result.scalars().all() if chunks: From 311570b4f0e097914e816630b87c1e0597418b93 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Fri, 12 Jun 2026 18:53:21 +0200 Subject: [PATCH 093/212] test(indexing): cover the edit path and make integration caches hermetic Real-DB tests assert unchanged chunk rows survive edits, only new text is embedded, removed rows are deleted with positions compacted, and the kill switch restores full-replace. An autouse fixture disables the ETL/embedding caches so a developer's .env can't leak cache hits into unrelated tests. --- .../tests/integration/conftest.py | 13 ++ .../indexing_pipeline/test_index_editions.py | 193 ++++++++++++++++++ 2 files changed, 206 insertions(+) create mode 100644 surfsense_backend/tests/integration/indexing_pipeline/test_index_editions.py diff --git a/surfsense_backend/tests/integration/conftest.py b/surfsense_backend/tests/integration/conftest.py index 8457047ec..6b8aa3cdb 100644 --- a/surfsense_backend/tests/integration/conftest.py +++ b/surfsense_backend/tests/integration/conftest.py @@ -123,6 +123,19 @@ async def db_search_space(db_session: AsyncSession, db_user: User) -> SearchSpac return space +@pytest.fixture(autouse=True) +def _derivation_caches_disabled(monkeypatch): + """Keep integration tests hermetic regardless of the developer's .env. + + With the embedding cache enabled, a successful index of some markdown makes + every later index of the same markdown a cache hit -- silently bypassing + patched ``embed_texts`` fakes/failure injections in unrelated tests. Cache + tests opt back in explicitly via ``monkeypatch.setattr``. + """ + monkeypatch.setattr(app_config, "ETL_CACHE_ENABLED", False) + monkeypatch.setattr(app_config, "EMBEDDING_CACHE_ENABLED", False) + + @pytest.fixture def patched_embed_texts(monkeypatch) -> MagicMock: mock = MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]) diff --git a/surfsense_backend/tests/integration/indexing_pipeline/test_index_editions.py b/surfsense_backend/tests/integration/indexing_pipeline/test_index_editions.py new file mode 100644 index 000000000..68d5ec0af --- /dev/null +++ b/surfsense_backend/tests/integration/indexing_pipeline/test_index_editions.py @@ -0,0 +1,193 @@ +"""Edit path: re-indexing a document diffs chunks instead of replacing them. + +Unchanged paragraphs must keep their chunk rows (ids survive -> embeddings and +HNSW entries untouched), only new text is embedded, removed text is deleted, +and (position) keeps presentation order correct throughout. +""" + +import pytest +from sqlalchemy import select + +from app.db import Chunk, DocumentStatus +from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService + +pytestmark = pytest.mark.integration + +_V1 = "Intro paragraph.\n\nBody paragraph.\n\nOutro paragraph." + + +@pytest.fixture +def paragraph_chunker(monkeypatch): + """One chunk per markdown paragraph, so edits map to chunk-level diffs.""" + + def _split(markdown, **_kwargs): + return [p for p in markdown.split("\n\n") if p.strip()] + + monkeypatch.setattr( + "app.indexing_pipeline.cache.cached_indexing.chunk_text", _split + ) + monkeypatch.setattr( + "app.indexing_pipeline.cache.cached_indexing.chunk_text_hybrid", _split + ) + + +async def _index(service, connector_doc): + prepared = await service.prepare_for_indexing([connector_doc]) + document = prepared[0] + await service.index(document, connector_doc) + return document + + +async def _load_chunks(db_session, document_id): + result = await db_session.execute( + select(Chunk) + .where(Chunk.document_id == document_id) + .order_by(Chunk.position, Chunk.id) + ) + return result.scalars().all() + + +@pytest.mark.usefixtures("paragraph_chunker") +async def test_edit_keeps_unchanged_rows_and_embeds_only_the_new_text( + db_session, + db_search_space, + make_connector_document, + patched_embed_texts, +): + service = IndexingPipelineService(session=db_session) + doc_v1 = make_connector_document( + search_space_id=db_search_space.id, source_markdown=_V1 + ) + document = await _index(service, doc_v1) + + ids_v1 = {c.content: c.id for c in await _load_chunks(db_session, document.id)} + patched_embed_texts.reset_mock() + + edited = "Intro paragraph.\n\nBody paragraph EDITED.\n\nOutro paragraph." + doc_v2 = make_connector_document( + search_space_id=db_search_space.id, source_markdown=edited + ) + await _index(service, doc_v2) + + chunks = await _load_chunks(db_session, document.id) + by_content = {c.content: c for c in chunks} + + # Untouched paragraphs keep their rows (same ids => embeddings reused, + # no HNSW/GIN churn); the edited paragraph got a fresh row. + assert by_content["Intro paragraph."].id == ids_v1["Intro paragraph."] + assert by_content["Outro paragraph."].id == ids_v1["Outro paragraph."] + assert "Body paragraph." not in by_content + assert by_content["Body paragraph EDITED."].id not in ids_v1.values() + + # Exactly one embed call: the document summary plus only the edited text. + (embedded_texts,) = patched_embed_texts.call_args.args + assert embedded_texts == [edited, "Body paragraph EDITED."] + + assert [c.position for c in chunks] == [0, 1, 2] + assert [c.content for c in chunks] == [ + "Intro paragraph.", + "Body paragraph EDITED.", + "Outro paragraph.", + ] + + +@pytest.mark.usefixtures("paragraph_chunker", "patched_embed_texts") +async def test_head_insert_shifts_positions_without_new_rows_for_old_text( + db_session, + db_search_space, + make_connector_document, +): + service = IndexingPipelineService(session=db_session) + document = await _index( + service, + make_connector_document( + search_space_id=db_search_space.id, source_markdown=_V1 + ), + ) + ids_v1 = {c.content: c.id for c in await _load_chunks(db_session, document.id)} + + await _index( + service, + make_connector_document( + search_space_id=db_search_space.id, + source_markdown="Brand new opener.\n\n" + _V1, + ), + ) + + chunks = await _load_chunks(db_session, document.id) + assert [c.content for c in chunks] == [ + "Brand new opener.", + "Intro paragraph.", + "Body paragraph.", + "Outro paragraph.", + ] + assert [c.position for c in chunks] == [0, 1, 2, 3] + # The three original rows survived the shift. + surviving = {c.content: c.id for c in chunks if c.content in ids_v1} + assert surviving == ids_v1 + + +@pytest.mark.usefixtures("paragraph_chunker", "patched_embed_texts") +async def test_removed_paragraph_is_deleted_and_order_compacts( + db_session, + db_search_space, + make_connector_document, +): + service = IndexingPipelineService(session=db_session) + document = await _index( + service, + make_connector_document( + search_space_id=db_search_space.id, source_markdown=_V1 + ), + ) + ids_v1 = {c.content: c.id for c in await _load_chunks(db_session, document.id)} + + await _index( + service, + make_connector_document( + search_space_id=db_search_space.id, + source_markdown="Intro paragraph.\n\nOutro paragraph.", + ), + ) + + chunks = await _load_chunks(db_session, document.id) + assert [(c.content, c.position) for c in chunks] == [ + ("Intro paragraph.", 0), + ("Outro paragraph.", 1), + ] + assert chunks[0].id == ids_v1["Intro paragraph."] + assert chunks[1].id == ids_v1["Outro paragraph."] + + +@pytest.mark.usefixtures("paragraph_chunker", "patched_embed_texts") +async def test_kill_switch_falls_back_to_full_replace( + db_session, + db_search_space, + make_connector_document, + monkeypatch, +): + from app.config import config + + service = IndexingPipelineService(session=db_session) + document = await _index( + service, + make_connector_document( + search_space_id=db_search_space.id, source_markdown=_V1 + ), + ) + ids_v1 = {c.id for c in await _load_chunks(db_session, document.id)} + + monkeypatch.setattr(config, "CHUNK_RECONCILE_ENABLED", False) + await _index( + service, + make_connector_document( + search_space_id=db_search_space.id, + source_markdown=_V1 + "\n\nAppended paragraph.", + ), + ) + + chunks = await _load_chunks(db_session, document.id) + # Legacy behavior: every row is recreated, even unchanged paragraphs. + assert {c.id for c in chunks}.isdisjoint(ids_v1) + assert [c.position for c in chunks] == [0, 1, 2, 3] + assert DocumentStatus.is_state(document.status, DocumentStatus.READY) From 407f2a9612acbab423a5b490e1caa079158a3ecb Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 22:41:21 +0530 Subject: [PATCH 094/212] feat(model-connections): enhance model connection functionality with preview and selection features --- .../app/routes/model_connections_routes.py | 84 +++++++- surfsense_backend/app/schemas/__init__.py | 2 + .../app/schemas/model_connections.py | 27 +++ .../app/services/model_connection_service.py | 45 ++++ .../app/services/provider_registry.py | 3 +- .../model-connections-mutation.atoms.ts | 11 + .../agent-action-log/action-log-dialog.tsx | 4 +- .../settings/model-connections-settings.tsx | 192 ++++++++++-------- .../model-connections/azure-connect-form.tsx | 65 +++--- .../bedrock-connect-form.tsx | 144 ++++++------- .../model-connections/connect-fields.tsx | 32 ++- .../connection-settings-dialog.tsx | 20 +- .../default-connect-form.tsx | 48 ++--- .../settings/model-connections/model-utils.ts | 13 +- .../models-selection-panel.tsx | 10 +- .../provider-connect-dialog.tsx | 171 ++++++++-------- .../model-connections/provider-metadata.tsx | 4 +- .../model-connections/vertex-connect-form.tsx | 149 +++++++------- .../types/model-connections.types.ts | 19 ++ .../lib/apis/model-connections-api.service.ts | 16 ++ 20 files changed, 630 insertions(+), 429 deletions(-) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 730c68565..2405843a7 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -21,10 +21,12 @@ from app.schemas import ( ConnectionRead, ConnectionUpdate, ModelCreate, + ModelPreviewRead, ModelProviderRead, ModelRead, ModelRolesRead, ModelRolesUpdate, + ModelSelection, ModelsBulkUpdate, ModelUpdate, VerifyConnectionResponse, @@ -48,6 +50,21 @@ def _model_read(model: Model | dict) -> ModelRead: return ModelRead.model_validate(model) +def _preview_model_read(item: dict) -> ModelPreviewRead: + return ModelPreviewRead( + model_id=item["model_id"], + display_name=item.get("display_name"), + source=item.get("source", ModelSource.DISCOVERED), + supports_chat=item.get("supports_chat"), + max_input_tokens=item.get("max_input_tokens"), + supports_image_input=item.get("supports_image_input"), + supports_tools=item.get("supports_tools"), + supports_image_generation=item.get("supports_image_generation"), + enabled=item.get("enabled", False), + metadata=item.get("metadata") or item.get("catalog") or {}, + ) + + def _connection_read(conn: Connection | dict, models: list[Model | dict] | None = None) -> ConnectionRead: if isinstance(conn, dict): payload = { @@ -86,6 +103,25 @@ def _apply_model_facts(model: Model, facts: dict) -> None: model.supports_image_generation = facts.get("supports_image_generation") +def _selection_to_model(conn: Connection, selection: ModelSelection) -> Model: + source = ( + selection.source + if isinstance(selection.source, ModelSource) + else ModelSource(selection.source) + ) + model = Model( + connection_id=conn.id, + model_id=selection.model_id.strip(), + display_name=selection.display_name, + source=source, + capabilities_override={}, + enabled=selection.enabled, + catalog=selection.metadata, + ) + _apply_model_facts(model, selection.model_dump()) + return model + + def _default_model_for(models: list[Model], capability: str) -> int | None: for model in models: if model.enabled and has_capability(model, capability): @@ -226,7 +262,7 @@ async def create_connection( Permission.LLM_CONFIGS_CREATE.value, "You don't have permission to create model connections in this search space", ) - payload = data.model_dump(exclude={"search_space_id"}) + payload = data.model_dump(exclude={"search_space_id", "models"}) conn = Connection( **payload, @@ -234,9 +270,51 @@ async def create_connection( user_id=user.id, ) session.add(conn) + await session.flush() + + seen_model_ids: set[str] = set() + for selection in data.models: + model_id = selection.model_id.strip() + if not model_id or model_id in seen_model_ids: + continue + seen_model_ids.add(model_id) + session.add(_selection_to_model(conn, selection)) + await session.commit() - await session.refresh(conn) - return _connection_read(conn, []) + conn = await _load_connection(session, conn.id) + await _default_unset_roles(session, conn, list(conn.models)) + await session.commit() + conn = await _load_connection(session, conn.id) + return _connection_read(conn, list(conn.models)) + + +@router.post("/model-connections/discover-preview", response_model=list[ModelPreviewRead]) +async def preview_connection_models( + data: ConnectionCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None: + await check_permission( + session, + user, + data.search_space_id, + Permission.LLM_CONFIGS_CREATE.value, + "You don't have permission to create model connections in this search space", + ) + + draft = Connection( + provider=data.provider, + base_url=data.base_url, + api_key=data.api_key, + extra=data.extra or {}, + scope=data.scope, + enabled=data.enabled, + search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None, + user_id=user.id, + ) + discovered = await discover_models(draft) + return [_preview_model_read(item) for item in discovered] @router.put("/model-connections/{connection_id}", response_model=ConnectionRead) diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 55e712f12..efa448dcd 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -49,10 +49,12 @@ from .model_connections import ( ConnectionRead, ConnectionUpdate, ModelCreate, + ModelPreviewRead, ModelProviderRead, ModelRead, ModelRolesRead, ModelRolesUpdate, + ModelSelection, ModelsBulkUpdate, ModelUpdate, VerifyConnectionResponse, diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py index 0b03c7fab..896532d6f 100644 --- a/surfsense_backend/app/schemas/model_connections.py +++ b/surfsense_backend/app/schemas/model_connections.py @@ -48,6 +48,32 @@ class ConnectionRead(BaseModel): model_config = ConfigDict(from_attributes=True) +class ModelSelection(BaseModel): + model_id: str = Field(..., max_length=255) + display_name: str | None = Field(None, max_length=255) + source: ModelSource | str = ModelSource.DISCOVERED + supports_chat: bool | None = None + max_input_tokens: int | None = None + supports_image_input: bool | None = None + supports_tools: bool | None = None + supports_image_generation: bool | None = None + enabled: bool = False + metadata: dict[str, Any] = Field(default_factory=dict) + + +class ModelPreviewRead(BaseModel): + model_id: str + display_name: str | None = None + source: ModelSource | str = ModelSource.DISCOVERED + supports_chat: bool | None = None + max_input_tokens: int | None = None + supports_image_input: bool | None = None + supports_tools: bool | None = None + supports_image_generation: bool | None = None + enabled: bool = False + metadata: dict[str, Any] = Field(default_factory=dict) + + class ConnectionCreate(BaseModel): provider: str = Field(..., max_length=100) base_url: str | None = Field(None, max_length=500) @@ -56,6 +82,7 @@ class ConnectionCreate(BaseModel): scope: ConnectionScope = ConnectionScope.SEARCH_SPACE search_space_id: int | None = None enabled: bool = True + models: list[ModelSelection] = Field(default_factory=list) class ConnectionUpdate(BaseModel): diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index 428af736e..7742e837e 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -8,6 +8,7 @@ from dataclasses import dataclass from datetime import UTC, datetime from typing import Any +import anyio import httpx import litellm @@ -292,6 +293,48 @@ def _litellm_static_models(conn: Connection) -> list[dict[str, Any]]: return results +async def _discover_bedrock_models(conn: Connection) -> list[dict[str, Any]]: + params = (conn.extra or {}).get("litellm_params", {}) + region_name = params.get("aws_region_name") + if not region_name: + return [] + + def list_models() -> list[dict[str, Any]]: + import boto3 + + client_kwargs: dict[str, str] = {"region_name": region_name} + if params.get("aws_access_key_id"): + client_kwargs["aws_access_key_id"] = params["aws_access_key_id"] + if params.get("aws_secret_access_key"): + client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"] + + client = boto3.client("bedrock", **client_kwargs) + response = client.list_foundation_models() + results: list[dict[str, Any]] = [] + for item in response.get("modelSummaries", []): + model_id = item.get("modelId") + if not model_id: + continue + input_modalities = set(item.get("inputModalities") or []) + output_modalities = set(item.get("outputModalities") or []) + results.append( + { + "model_id": model_id, + "display_name": item.get("modelName") or model_id, + "source": ModelSource.DISCOVERED, + "supports_chat": "TEXT" in input_modalities and "TEXT" in output_modalities, + "supports_image_input": "IMAGE" in input_modalities, + "supports_tools": None, + "supports_image_generation": "IMAGE" in output_modalities, + "max_input_tokens": None, + "metadata": item, + } + ) + return results + + return await anyio.to_thread.run_sync(list_models) + + async def discover_models(conn: Connection) -> list[dict[str, Any]]: allowlist = _allowlist(conn) spec = spec_for(conn.provider) @@ -304,6 +347,8 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]: results = await _discover_anthropic_models(conn) elif spec.discovery == "openai_models": results = await _discover_openai_shaped_models(conn, conn.base_url) + elif spec.discovery == "bedrock_models": + results = await _discover_bedrock_models(conn) elif spec.discovery == "static": results = _litellm_static_models(conn) else: diff --git a/surfsense_backend/app/services/provider_registry.py b/surfsense_backend/app/services/provider_registry.py index 871769f11..98bfb63c1 100644 --- a/surfsense_backend/app/services/provider_registry.py +++ b/surfsense_backend/app/services/provider_registry.py @@ -21,6 +21,7 @@ DiscoveryKind = Literal[ "ollama", "openai_models", "anthropic_models", + "bedrock_models", "openrouter", "static", "none", @@ -51,7 +52,7 @@ REGISTRY: dict[str, ProviderSpec] = { Transport.NATIVE, "vertex_ai", "static", None, False, "native" ), "bedrock": ProviderSpec( - Transport.NATIVE, "bedrock", "static", None, False, "native" + Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native" ), "openrouter": ProviderSpec( Transport.OPENAI_COMPATIBLE, diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts index fee3b95ba..ea91c6483 100644 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -4,6 +4,7 @@ import type { ConnectionCreateRequest, ConnectionUpdateRequest, ModelCreateRequest, + ModelPreviewRead, ModelRead, ModelRoles, ModelsBulkUpdateRequest, @@ -103,6 +104,16 @@ export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => { }; }); +export const previewConnectionModelsMutationAtom = atomWithMutation(() => { + return { + mutationKey: ["model-connections", "discover-preview"], + mutationFn: (request: ConnectionCreateRequest) => + modelConnectionsApiService.previewModels(request), + onSuccess: (_models: ModelPreviewRead[]) => {}, + onError: (error: Error) => toast.error(error.message || "Failed to discover models"), + }; +}); + export const addManualModelMutationAtom = atomWithMutation((get) => { const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); return { diff --git a/surfsense_web/components/agent-action-log/action-log-dialog.tsx b/surfsense_web/components/agent-action-log/action-log-dialog.tsx index 1d0eefc17..5f3b83db1 100644 --- a/surfsense_web/components/agent-action-log/action-log-dialog.tsx +++ b/surfsense_web/components/agent-action-log/action-log-dialog.tsx @@ -2,7 +2,7 @@ import { useQueryClient } from "@tanstack/react-query"; import { useAtom, useAtomValue } from "jotai"; -import { RefreshCcw, Workflow } from "lucide-react"; +import { RefreshCw, Workflow } from "lucide-react"; import { useCallback } from "react"; import { actionLogDialogAtom } from "@/atoms/agent/action-log-dialog.atom"; import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; @@ -112,7 +112,7 @@ export function ActionLogDialog() { className="absolute right-14 top-4 size-8 rounded-full p-0 text-muted-foreground hover:bg-accent hover:text-accent-foreground" aria-label="Refresh action log" > - +
diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 2e15ce2e9..6c3d1a411 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -4,12 +4,9 @@ import { useAtom, useAtomValue } from "jotai"; import { CheckCircle2, Trash2, XCircle } from "lucide-react"; import { useState } from "react"; import { - addManualModelMutationAtom, - bulkUpdateModelsMutationAtom, createModelConnectionMutationAtom, deleteModelConnectionMutationAtom, - discoverConnectionModelsMutationAtom, - updateModelMutationAtom, + previewConnectionModelsMutationAtom, updateModelRolesMutationAtom, } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { @@ -41,9 +38,13 @@ import { SelectValue, } from "@/components/ui/select"; import { Separator } from "@/components/ui/separator"; -import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; +import type { + ConnectionRead, + ModelRead, + ModelSelection, +} from "@/contracts/types/model-connections.types"; import { ConnectionSettingsDialog } from "./model-connections/connection-settings-dialog"; -import { capability, modelLabel } from "./model-connections/model-utils"; +import { capability, modelLabel, type SelectableModel } from "./model-connections/model-utils"; import { ProviderConnectDialog } from "./model-connections/provider-connect-dialog"; import { type ConnectionDraft, @@ -154,16 +155,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num const [{ data: providers = [] }] = useAtom(modelProvidersAtom); const [{ data: roles }] = useAtom(modelRolesAtom); const createConnection = useAtomValue(createModelConnectionMutationAtom); - const addManualModel = useAtomValue(addManualModelMutationAtom); - const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); - const updateModel = useAtomValue(updateModelMutationAtom); - const bulkUpdateModels = useAtomValue(bulkUpdateModelsMutationAtom); + const previewModels = useAtomValue(previewConnectionModelsMutationAtom); const updateRoles = useAtomValue(updateModelRolesMutationAtom); const [isAddProviderOpen, setIsAddProviderOpen] = useState(false); const [provider, setProvider] = useState("openai_compatible"); - const [connectedConnection, setConnectedConnection] = useState(null); - const [connectModels, setConnectModels] = useState([]); + const [connectModels, setConnectModels] = useState([]); const selectedProvider = providers.find((item) => item.provider === provider); const sortedProviders = [...providers].sort((left, right) => { @@ -185,7 +182,6 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num const imageModels = enabledModels.filter((model) => capability(model, "image_gen")); function resetConnectState() { - setConnectedConnection(null); setConnectModels([]); } @@ -196,15 +192,48 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num } } - function replaceConnectModels(updatedModels: ModelRead[]) { - setConnectModels((current) => - current.map((model) => updatedModels.find((updated) => updated.id === model.id) ?? model) - ); + function toModelSelection(model: SelectableModel): ModelSelection { + return { + model_id: model.model_id, + display_name: model.display_name, + source: model.source || "DISCOVERED", + supports_chat: model.supports_chat, + max_input_tokens: model.max_input_tokens, + supports_image_input: model.supports_image_input, + supports_tools: model.supports_tools, + supports_image_generation: model.supports_image_generation, + enabled: model.enabled, + metadata: "metadata" in model ? (model.metadata ?? {}) : (model.catalog ?? {}), + }; + } + + function mergePreviewModels(fetchedModels: SelectableModel[]) { + setConnectModels((current) => { + const currentById = new Map(current.map((model) => [model.model_id, model])); + return fetchedModels.map((model) => { + const prior = currentById.get(model.model_id); + return { + ...toModelSelection(model), + enabled: prior ? prior.enabled : model.enabled, + }; + }); + }); } // Each provider connect form builds its own credential payload; the backend // resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM. function handleCreate(draft: ConnectionDraft) { + const models = [...connectModels]; + if (draft.seedModelId && !models.some((model) => model.model_id === draft.seedModelId)) { + models.push({ + model_id: draft.seedModelId, + display_name: draft.seedModelId, + source: "MANUAL", + enabled: true, + metadata: {}, + }); + } + createConnection.mutate( { provider, @@ -214,26 +243,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num search_space_id: searchSpaceId, extra: draft.extra, enabled: true, + models, }, { - onSuccess: (created) => { - setConnectedConnection(created); - setConnectModels([]); - if (draft.seedModelId) { - addManualModel.mutate( - { - connectionId: created.id, - data: { model_id: draft.seedModelId }, - }, - { - onSuccess: (model) => setConnectModels([model]), - } - ); - } else { - discoverModels.mutate(created.id, { - onSuccess: (models) => setConnectModels(models), - }); - } + onSuccess: () => { + setIsAddProviderOpen(false); + resetConnectState(); }, } ); @@ -243,52 +258,72 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num resetConnectState(); setProvider(providerId); setIsAddProviderOpen(true); + if (providerId === "vertex_ai") { + previewModels.mutate( + { + provider: providerId, + base_url: null, + api_key: null, + scope: "SEARCH_SPACE", + search_space_id: searchSpaceId, + extra: {}, + enabled: true, + models: [], + }, + { + onSuccess: mergePreviewModels, + } + ); + } } - function refreshConnectModels() { - if (!connectedConnection) return; - discoverModels.mutate(connectedConnection.id, { - onSuccess: (models) => setConnectModels(models), - }); + function refreshConnectModels(draft: ConnectionDraft) { + previewModels.mutate( + { + provider, + base_url: draft.base_url, + api_key: draft.api_key, + scope: "SEARCH_SPACE", + search_space_id: searchSpaceId, + extra: draft.extra, + enabled: true, + models: [], + }, + { + onSuccess: mergePreviewModels, + } + ); } function addConnectModel(modelId: string) { - if (!connectedConnection) return; - addManualModel.mutate( - { connectionId: connectedConnection.id, data: { model_id: modelId } }, - { - onSuccess: (model) => setConnectModels((current) => [...current, model]), - } + setConnectModels((current) => { + if (current.some((model) => model.model_id === modelId)) return current; + return [ + ...current, + { + model_id: modelId, + display_name: modelId, + source: "MANUAL", + enabled: true, + metadata: {}, + }, + ]; + }); + } + + function toggleConnectModel(model: SelectableModel, enabled: boolean) { + setConnectModels((current) => + current.map((item) => (item.model_id === model.model_id ? { ...item, enabled } : item)) ); } - function toggleConnectModel(model: ModelRead, enabled: boolean) { - updateModel.mutate( - { id: model.id, data: { enabled } }, - { - onSuccess: (updated) => replaceConnectModels([updated]), - } + function bulkToggleConnectModels(models: SelectableModel[], enabled: boolean) { + const modelIds = new Set(models.map((model) => model.model_id)); + setConnectModels((current) => + current.map((item) => (modelIds.has(item.model_id) ? { ...item, enabled } : item)) ); } - function bulkToggleConnectModels(models: ModelRead[], enabled: boolean) { - if (!connectedConnection) return; - bulkUpdateModels.mutate( - { - connectionId: connectedConnection.id, - data: { model_ids: models.map((model) => model.id), enabled }, - }, - { - onSuccess: replaceConnectModels, - } - ); - } - - function finishConnectFlow() { - setIsAddProviderOpen(false); - resetConnectState(); - } - function renderModelOption(model: ModelRead & { connectionName: string; provider: string }) { return ( @@ -347,17 +382,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num selectedProvider={selectedProvider} isPending={createConnection.isPending} onSubmit={handleCreate} - connectedConnection={connectedConnection} - connectModels={connectModels} - isDiscoveringModels={discoverModels.isPending} - isAddingManualModel={addManualModel.isPending} - isUpdatingModel={updateModel.isPending} - isBulkUpdatingModels={bulkUpdateModels.isPending} - onRefreshModels={refreshConnectModels} - onAddManualModel={addConnectModel} - onToggleModel={toggleConnectModel} - onBulkToggleModels={bulkToggleConnectModels} - onDone={finishConnectFlow} + previewModels={connectModels} + isPreviewingModels={previewModels.isPending} + onPreviewModels={refreshConnectModels} + onAddPreviewModel={addConnectModel} + onTogglePreviewModel={toggleConnectModel} + onBulkTogglePreviewModels={bulkToggleConnectModels} /> {connections.length > 0 ? ( diff --git a/surfsense_web/components/settings/model-connections/azure-connect-form.tsx b/surfsense_web/components/settings/model-connections/azure-connect-form.tsx index 11a2e25d3..451f053db 100644 --- a/surfsense_web/components/settings/model-connections/azure-connect-form.tsx +++ b/surfsense_web/components/settings/model-connections/azure-connect-form.tsx @@ -1,7 +1,7 @@ -import { useState } from "react"; +import { useEffect, useState } from "react"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; -import { ApiKeyField, ConnectFormFooter } from "./connect-fields"; +import { ApiKeyField } from "./connect-fields"; import { isValidAzureTargetUri, type ProviderConnectFormProps, @@ -12,48 +12,43 @@ import { * Azure OpenAI connect form. The user pastes a single Target URI, which we parse * into api base, api version, and the deployment name (seeded as the model). */ -export function AzureConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) { +export function AzureConnectForm({ onDraftChange }: ProviderConnectFormProps) { const [targetUri, setTargetUri] = useState(""); const [apiKey, setApiKey] = useState(""); const canSubmit = isValidAzureTargetUri(targetUri) && Boolean(apiKey.trim()); - function handleSubmit() { + useEffect(() => { const parsed = parseAzureTargetUri(targetUri); - onSubmit({ - base_url: parsed?.origin ?? null, - api_key: apiKey || null, - extra: parsed?.apiVersion ? { api_version: parsed.apiVersion } : {}, - seedModelId: parsed?.deploymentName || undefined, - }); - } + onDraftChange( + { + base_url: parsed?.origin ?? null, + api_key: apiKey || null, + extra: parsed?.apiVersion ? { api_version: parsed.apiVersion } : {}, + seedModelId: parsed?.deploymentName || undefined, + }, + canSubmit + ); + }, [apiKey, canSubmit, onDraftChange, targetUri]); return ( - <> -
-
- - setTargetUri(event.target.value)} - placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview" - /> -

- Paste your endpoint target URI from Azure OpenAI (including API base, deployment name, - and API version). -

-
- +
+ + setTargetUri(event.target.value)} + placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview" /> +

+ Paste your endpoint target URI from Azure OpenAI (including API base, deployment name, and + API version). +

- - +
); } diff --git a/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx b/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx index 9466c7cd1..3115ac223 100644 --- a/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx +++ b/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx @@ -1,4 +1,4 @@ -import { useState } from "react"; +import { useEffect, useState } from "react"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { @@ -8,7 +8,7 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; -import { ConnectFormFooter } from "./connect-fields"; +import { ApiKeyField } from "./connect-fields"; import { AWS_REGION_OPTIONS, BEDROCK_AUTH_ACCESS_KEY, @@ -21,7 +21,7 @@ import { * Amazon Bedrock connect form. Region + auth method drive which AWS credentials * are collected; everything rides along in `extra.litellm_params`. */ -export function BedrockConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) { +export function BedrockConnectForm({ onDraftChange }: ProviderConnectFormProps) { const [region, setRegion] = useState(""); const [authMethod, setAuthMethod] = useState(BEDROCK_AUTH_ACCESS_KEY); const [accessKeyId, setAccessKeyId] = useState(""); @@ -39,7 +39,7 @@ export function BedrockConnectForm({ isPending, onCancel, onSubmit }: ProviderCo return true; })(); - function handleSubmit() { + useEffect(() => { const params: Record = { aws_region_name: region }; if (authMethod === BEDROCK_AUTH_ACCESS_KEY) { params.aws_access_key_id = accessKeyId; @@ -47,88 +47,74 @@ export function BedrockConnectForm({ isPending, onCancel, onSubmit }: ProviderCo } else if (authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY) { params.aws_bearer_token_bedrock = bearerToken; } - onSubmit({ base_url: null, api_key: null, extra: { litellm_params: params } }); - } + onDraftChange({ base_url: null, api_key: null, extra: { litellm_params: params } }, canSubmit); + }, [accessKeyId, authMethod, bearerToken, canSubmit, onDraftChange, region, secretAccessKey]); return ( - <> -
-
- - -
-
- - -
- {authMethod === BEDROCK_AUTH_ACCESS_KEY ? ( - <> -
- - setAccessKeyId(event.target.value)} - placeholder="AKIAIOSFODNN7EXAMPLE" - /> -
-
- - setSecretAccessKey(event.target.value)} - placeholder="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" - type="password" - /> -
- - ) : null} - {authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY ? ( +
+
+ + +
+
+ + +
+ {authMethod === BEDROCK_AUTH_ACCESS_KEY ? ( + <>
- + setBearerToken(event.target.value)} - placeholder="Your long-term API key" - type="password" + value={accessKeyId} + onChange={(event) => setAccessKeyId(event.target.value)} + placeholder="AKIAIOSFODNN7EXAMPLE" />
- ) : null} - {authMethod === BEDROCK_AUTH_IAM ? ( -

- SurfSense will use the IAM role attached to the environment it's running in to - authenticate. -

- ) : null} + + + ) : null} + {authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY ? ( + + ) : null} + {authMethod === BEDROCK_AUTH_IAM ? (

- Add Bedrock model IDs from the provider's settings after connecting. + SurfSense will use the IAM role attached to the environment it's running in to + authenticate.

-
- - + ) : null} +

+ Add Bedrock model IDs from the provider's settings after connecting. +

+
); } diff --git a/surfsense_web/components/settings/model-connections/connect-fields.tsx b/surfsense_web/components/settings/model-connections/connect-fields.tsx index af8db7f12..44b2d434f 100644 --- a/surfsense_web/components/settings/model-connections/connect-fields.tsx +++ b/surfsense_web/components/settings/model-connections/connect-fields.tsx @@ -1,3 +1,5 @@ +import { Eye, EyeOff } from "lucide-react"; +import { useState } from "react"; import { Button } from "@/components/ui/button"; import { DialogFooter } from "@/components/ui/dialog"; import { Input } from "@/components/ui/input"; @@ -43,15 +45,31 @@ export function ApiKeyField({ label = "API Key", placeholder = "API key", }: ApiKeyFieldProps) { + const [showApiKey, setShowApiKey] = useState(false); + return (
- onChange(event.target.value)} - placeholder={placeholder} - type="password" - /> +
+ onChange(event.target.value)} + placeholder={placeholder} + type={showApiKey ? "text" : "password"} + className="pr-11" + /> + +
); } @@ -71,7 +89,7 @@ export function ConnectFormFooter({ isPending, }: ConnectFormFooterProps) { return ( - + diff --git a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx index f3821af46..d0f8e6c16 100644 --- a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx +++ b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx @@ -25,8 +25,8 @@ import { Separator } from "@/components/ui/separator"; import type { ConnectionRead, ConnectionUpdateRequest, - ModelRead, } from "@/contracts/types/model-connections.types"; +import type { SelectableModel } from "./model-utils"; import { ModelsSelectionPanel } from "./models-selection-panel"; import { providerIcon } from "./provider-metadata"; @@ -101,17 +101,22 @@ export function ConnectionSettingsDialog({ }); } - function handleToggleModel(model: ModelRead, enabled: boolean) { + function handleToggleModel(model: SelectableModel, enabled: boolean) { + if (typeof model.id !== "number") return; updateModel.mutate({ id: model.id, data: { enabled }, }); } - function handleBulkToggle(models: ModelRead[], enabled: boolean) { + function handleBulkToggle(models: SelectableModel[], enabled: boolean) { + const modelIds = models + .map((model) => model.id) + .filter((id): id is number => typeof id === "number"); + if (modelIds.length === 0) return; bulkUpdateModels.mutate({ connectionId: connection.id, - data: { model_ids: models.map((model) => model.id), enabled }, + data: { model_ids: modelIds, enabled }, }); } @@ -184,12 +189,7 @@ export function ConnectionSettingsDialog({ onChange={(event) => setAllowlistText(event.target.value)} placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro" /> -
diff --git a/surfsense_web/components/settings/model-connections/default-connect-form.tsx b/surfsense_web/components/settings/model-connections/default-connect-form.tsx index 3f261c6b2..768c0b5da 100644 --- a/surfsense_web/components/settings/model-connections/default-connect-form.tsx +++ b/surfsense_web/components/settings/model-connections/default-connect-form.tsx @@ -1,5 +1,5 @@ -import { useState } from "react"; -import { ApiBaseUrlField, ApiKeyField, ConnectFormFooter } from "./connect-fields"; +import { useEffect, useState } from "react"; +import { ApiBaseUrlField, ApiKeyField } from "./connect-fields"; import type { ProviderConnectFormProps } from "./provider-metadata"; /** @@ -11,41 +11,31 @@ export function DefaultConnectForm({ provider, defaultBaseUrl, baseUrlRequired, - isPending, - onCancel, - onSubmit, + onDraftChange, }: ProviderConnectFormProps) { const [baseUrl, setBaseUrl] = useState(defaultBaseUrl); const [apiKey, setApiKey] = useState(""); const isOllama = provider === "ollama_chat"; const canSubmit = !(baseUrlRequired && !baseUrl.trim()); - function handleSubmit() { - onSubmit({ base_url: baseUrl || null, api_key: apiKey || null, extra: {} }); - } + useEffect(() => { + onDraftChange({ base_url: baseUrl || null, api_key: apiKey || null, extra: {} }, canSubmit); + }, [apiKey, baseUrl, canSubmit, onDraftChange]); return ( - <> -
- - -
- + - + +
); } diff --git a/surfsense_web/components/settings/model-connections/model-utils.ts b/surfsense_web/components/settings/model-connections/model-utils.ts index 1db14b3eb..2887f2179 100644 --- a/surfsense_web/components/settings/model-connections/model-utils.ts +++ b/surfsense_web/components/settings/model-connections/model-utils.ts @@ -1,4 +1,4 @@ -import type { ModelRead } from "@/contracts/types/model-connections.types"; +import type { ModelPreviewRead, ModelRead } from "@/contracts/types/model-connections.types"; export type ModelCapabilityFilter = "chat" | "vision" | "image_gen"; @@ -8,17 +8,22 @@ export const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: stri { key: "image_gen", label: "Image" }, ]; -export function modelLabel(model: ModelRead) { +export type SelectableModel = (ModelRead | ModelPreviewRead) & { + id?: number | string; + connection_id?: number; +}; + +export function modelLabel(model: SelectableModel) { return model.display_name || model.model_id; } -export function capability(model: ModelRead, key: ModelCapabilityFilter) { +export function capability(model: SelectableModel, key: ModelCapabilityFilter) { if (key === "chat") return Boolean(model.supports_chat); if (key === "vision") return Boolean(model.supports_image_input); return Boolean(model.supports_image_generation); } -export function capabilityLabels(model: ModelRead) { +export function capabilityLabels(model: SelectableModel) { return MODEL_CAPABILITY_FILTERS.filter((filter) => capability(model, filter.key)) .map((filter) => filter.label.toLowerCase()) .join(", "); diff --git a/surfsense_web/components/settings/model-connections/models-selection-panel.tsx b/surfsense_web/components/settings/model-connections/models-selection-panel.tsx index 20fbe862f..01ff0d1e7 100644 --- a/surfsense_web/components/settings/model-connections/models-selection-panel.tsx +++ b/surfsense_web/components/settings/model-connections/models-selection-panel.tsx @@ -4,17 +4,17 @@ import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Checkbox } from "@/components/ui/checkbox"; import { Input } from "@/components/ui/input"; -import type { ModelRead } from "@/contracts/types/model-connections.types"; import { capability, capabilityLabels, MODEL_CAPABILITY_FILTERS, type ModelCapabilityFilter, modelLabel, + type SelectableModel, } from "./model-utils"; interface ModelsSelectionPanelProps { - models: ModelRead[]; + models: SelectableModel[]; description?: string; emptyMessage?: string; manualInputPlaceholder?: string; @@ -25,8 +25,8 @@ interface ModelsSelectionPanelProps { isBulkUpdating?: boolean; onRefresh?: () => void; onAddManual?: (modelId: string) => void; - onToggleModel?: (model: ModelRead, enabled: boolean) => void; - onBulkToggle?: (models: ModelRead[], enabled: boolean) => void; + onToggleModel?: (model: SelectableModel, enabled: boolean) => void; + onBulkToggle?: (models: SelectableModel[], enabled: boolean) => void; } export function ModelsSelectionPanel({ @@ -166,7 +166,7 @@ export function ModelsSelectionPanel({
{filteredModels.map((model) => (
void; - connectedConnection?: ConnectionRead | null; - connectModels?: ModelRead[]; - isDiscoveringModels?: boolean; - isAddingManualModel?: boolean; - isUpdatingModel?: boolean; - isBulkUpdatingModels?: boolean; - onRefreshModels?: () => void; - onAddManualModel?: (modelId: string) => void; - onToggleModel?: (model: ModelRead, enabled: boolean) => void; - onBulkToggleModels?: (models: ModelRead[], enabled: boolean) => void; - onDone?: () => void; + previewModels?: SelectableModel[]; + isPreviewingModels?: boolean; + onPreviewModels?: (draft: ConnectionDraft) => void; + onAddPreviewModel?: (modelId: string) => void; + onTogglePreviewModel?: (model: SelectableModel, enabled: boolean) => void; + onBulkTogglePreviewModels?: (models: SelectableModel[], enabled: boolean) => void; } /** @@ -57,97 +50,93 @@ export function ProviderConnectDialog({ selectedProvider, isPending, onSubmit, - connectedConnection, - connectModels = [], - isDiscoveringModels = false, - isAddingManualModel = false, - isUpdatingModel = false, - isBulkUpdatingModels = false, - onRefreshModels, - onAddManualModel, - onToggleModel, - onBulkToggleModels, - onDone, + previewModels = [], + isPreviewingModels = false, + onPreviewModels, + onAddPreviewModel, + onTogglePreviewModel, + onBulkTogglePreviewModels, }: ProviderConnectDialogProps) { const meta = providerDisplay(provider); - const isModelSelectionStep = Boolean(connectedConnection); + const isAzure = provider === "azure"; + const isBedrock = provider === "bedrock"; + const isVertex = provider === "vertex_ai"; + const [currentDraft, setCurrentDraft] = useState({ + base_url: null, + api_key: null, + extra: {}, + }); + const [canSubmit, setCanSubmit] = useState(false); + + const handleDraftChange = useCallback((draft: ConnectionDraft, nextCanSubmit: boolean) => { + setCurrentDraft(draft); + setCanSubmit(nextCanSubmit); + }, []); const formProps: ProviderConnectFormProps = { provider, defaultBaseUrl: providerDefaultBaseUrl(provider, selectedProvider?.default_base_url), baseUrlRequired: Boolean(selectedProvider?.base_url_required), - isPending, - onCancel: () => onOpenChange(false), - onSubmit, + onDraftChange: handleDraftChange, }; + const modelDescription = (() => { + if (isAzure) { + return "Select the models to enable for Azure OpenAI"; + } + if (isBedrock) { + return "Select the models to enable for Amazon Bedrock"; + } + if (isVertex) { + return "Select the models to enable for Gemini"; + } + return "Select the models to enable for this provider"; + })(); + + const canRefreshModels = !isAzure && !isVertex && (!isBedrock || canSubmit); + return ( - +
{providerIcon(provider, "size-5")}
- - {isModelSelectionStep ? `Select ${meta.name} models` : `Connect ${meta.name}`} - - - {isModelSelectionStep - ? selectedProvider?.discovery === "static" - ? "Choose from known model IDs or add one manually." - : "Choose which discovered models should be available in this search space." - : meta.subtitle} - + Connect {meta.name} + {meta.subtitle}
- {isModelSelectionStep ? ( - <> -
- -
- - - - - ) : ( -
- {provider === "azure" ? ( - - ) : provider === "bedrock" ? ( - - ) : provider === "vertex_ai" ? ( - - ) : ( - - )} -
- )} +
+ {provider === "azure" ? ( + + ) : provider === "bedrock" ? ( + + ) : provider === "vertex_ai" ? ( + + ) : ( + + )} + + + + onPreviewModels?.(currentDraft) : undefined} + onAddManual={onAddPreviewModel} + onToggleModel={onTogglePreviewModel} + onBulkToggle={onBulkTogglePreviewModels} + /> +
+ onOpenChange(false)} + onSubmit={() => onSubmit(currentDraft)} + canSubmit={canSubmit} + isPending={isPending} + />
); diff --git a/surfsense_web/components/settings/model-connections/provider-metadata.tsx b/surfsense_web/components/settings/model-connections/provider-metadata.tsx index 0ca8ae419..73e873393 100644 --- a/surfsense_web/components/settings/model-connections/provider-metadata.tsx +++ b/surfsense_web/components/settings/model-connections/provider-metadata.tsx @@ -133,7 +133,5 @@ export interface ProviderConnectFormProps { provider: string; defaultBaseUrl: string; baseUrlRequired: boolean; - isPending: boolean; - onCancel: () => void; - onSubmit: (draft: ConnectionDraft) => void; + onDraftChange: (draft: ConnectionDraft, canSubmit: boolean) => void; } diff --git a/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx b/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx index 096d3df2e..1027742bc 100644 --- a/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx +++ b/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx @@ -1,4 +1,4 @@ -import { useState } from "react"; +import { useEffect, useState } from "react"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { @@ -8,7 +8,6 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; -import { ConnectFormFooter } from "./connect-fields"; import { type ProviderConnectFormProps, VERTEX_AUTH_SERVICE_ACCOUNT, @@ -21,7 +20,7 @@ import { * credentials JSON file (read into a string); workload identity collects a * project id. Credentials ride along in `extra.litellm_params`. */ -export function VertexConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) { +export function VertexConnectForm({ onDraftChange }: ProviderConnectFormProps) { const [authMethod, setAuthMethod] = useState(VERTEX_AUTH_SERVICE_ACCOUNT); const [location, setLocation] = useState(VERTEX_DEFAULT_LOCATION); const [credentials, setCredentials] = useState(""); @@ -35,7 +34,7 @@ export function VertexConnectForm({ isPending, onCancel, onSubmit }: ProviderCon setCredentials(await file.text()); } - function handleSubmit() { + useEffect(() => { const params: Record = {}; if (location) params.vertex_location = location; if (authMethod === VERTEX_AUTH_SERVICE_ACCOUNT) { @@ -43,85 +42,77 @@ export function VertexConnectForm({ isPending, onCancel, onSubmit }: ProviderCon } else if (project) { params.vertex_project = project; } - onSubmit({ base_url: null, api_key: null, extra: { litellm_params: params } }); - } + onDraftChange({ base_url: null, api_key: null, extra: { litellm_params: params } }, canSubmit); + }, [authMethod, canSubmit, credentials, location, onDraftChange, project]); return ( - <> -
-
- - -
-
- - setLocation(event.target.value)} - placeholder={VERTEX_DEFAULT_LOCATION} - /> -

- Region where your Google Vertex AI models are hosted. -

-
- {authMethod === VERTEX_AUTH_SERVICE_ACCOUNT ? ( -
- - handleCredentialsFile(event.target.files?.[0])} - /> - -

- {credentials - ? "Credentials file loaded." - : "Attach your service account key JSON from Google Cloud."} -

-
- ) : ( -
- - setProject(event.target.value)} - placeholder="my-vertex-project" - /> -

- The GCP project where Vertex AI is enabled. -

-
- )} +
+
+ + +
+
+ + setLocation(event.target.value)} + placeholder={VERTEX_DEFAULT_LOCATION} + />

- Add Vertex AI model IDs from the provider's settings after connecting. + Region where your Google Vertex AI models are hosted.

- - + {authMethod === VERTEX_AUTH_SERVICE_ACCOUNT ? ( +
+ + handleCredentialsFile(event.target.files?.[0])} + /> + +

+ {credentials + ? "Credentials file loaded." + : "Attach your service account key JSON from Google Cloud."} +

+
+ ) : ( +
+ + setProject(event.target.value)} + placeholder="my-vertex-project" + /> +

+ The GCP project where Vertex AI is enabled. +

+
+ )} +

+ Add Vertex AI model IDs from the provider's settings after connecting. +

+
); } diff --git a/surfsense_web/contracts/types/model-connections.types.ts b/surfsense_web/contracts/types/model-connections.types.ts index c75f4c90a..134c740b2 100644 --- a/surfsense_web/contracts/types/model-connections.types.ts +++ b/surfsense_web/contracts/types/model-connections.types.ts @@ -40,6 +40,21 @@ export const connectionRead = z.object({ created_at: z.string().nullable().optional(), }); +export const modelSelection = z.object({ + model_id: z.string().min(1), + display_name: z.string().nullable().optional(), + source: z.union([modelSourceEnum, z.string()]).default("DISCOVERED"), + supports_chat: z.boolean().nullable().optional(), + max_input_tokens: z.number().nullable().optional(), + supports_image_input: z.boolean().nullable().optional(), + supports_tools: z.boolean().nullable().optional(), + supports_image_generation: z.boolean().nullable().optional(), + enabled: z.boolean().default(false), + metadata: z.record(z.string(), z.any()).default({}), +}); + +export const modelPreviewRead = modelSelection; + export const connectionCreateRequest = z.object({ provider: z.string().min(1), base_url: z.string().nullable().optional(), @@ -48,6 +63,7 @@ export const connectionCreateRequest = z.object({ scope: connectionScopeEnum.default("SEARCH_SPACE"), search_space_id: z.number().nullable().optional(), enabled: z.boolean().default(true), + models: z.array(modelSelection).default([]), }); export const connectionUpdateRequest = z.object({ @@ -105,9 +121,12 @@ export const modelProviderListResponse = z.array(modelProviderRead); export const connectionListResponse = z.array(connectionRead); export const modelListResponse = z.array(modelRead); +export const modelPreviewListResponse = z.array(modelPreviewRead); export type ConnectionScope = z.infer; export type ModelRead = z.infer; +export type ModelPreviewRead = z.infer; +export type ModelSelection = z.infer; export type ConnectionRead = z.infer; export type ConnectionCreateRequest = z.infer; export type ConnectionUpdateRequest = z.infer; diff --git a/surfsense_web/lib/apis/model-connections-api.service.ts b/surfsense_web/lib/apis/model-connections-api.service.ts index bd5aa1309..f463a27e7 100644 --- a/surfsense_web/lib/apis/model-connections-api.service.ts +++ b/surfsense_web/lib/apis/model-connections-api.service.ts @@ -7,6 +7,7 @@ import { connectionRead, connectionUpdateRequest, type ModelCreateRequest, + type ModelPreviewRead, type ModelProviderRead, type ModelRead, type ModelRoles, @@ -14,6 +15,7 @@ import { type ModelUpdateRequest, modelCreateRequest, modelListResponse, + modelPreviewListResponse, modelProviderListResponse, modelRead, modelRoles, @@ -76,6 +78,20 @@ class ModelConnectionsApiService { return baseApiService.post(`/api/v1/model-connections/${id}/discover`, modelListResponse); }; + previewModels = async (request: ConnectionCreateRequest): Promise => { + const parsed = connectionCreateRequest.safeParse(request); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); + } + return baseApiService.post( + `/api/v1/model-connections/discover-preview`, + modelPreviewListResponse, + { + body: parsed.data, + } + ); + }; + addManualModel = async ( connectionId: number, request: ModelCreateRequest From 55f004e1da4bf6297aa8dcf3215ab0ca825c8051 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 22:50:50 +0530 Subject: [PATCH 095/212] feat(model-connections): improve model discovery error handling and enhance UI components --- .../app/routes/model_connections_routes.py | 11 +++- .../app/services/model_connection_service.py | 53 ++++++++++++++----- .../models-selection-panel.tsx | 15 +++--- 3 files changed, 55 insertions(+), 24 deletions(-) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 2405843a7..474d376d3 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -32,6 +32,7 @@ from app.schemas import ( VerifyConnectionResponse, ) from app.services.model_connection_service import ( + ModelDiscoveryError, derive_capabilities, discover_models, persist_verification, @@ -313,7 +314,10 @@ async def preview_connection_models( search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None, user_id=user.id, ) - discovered = await discover_models(draft) + try: + discovered = await discover_models(draft) + except ModelDiscoveryError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc return [_preview_model_read(item) for item in discovered] @@ -367,7 +371,10 @@ async def discover_connection_models( ): conn = await _load_connection(session, connection_id) await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_CREATE.value) - discovered = await discover_models(conn) + try: + discovered = await discover_models(conn) + except ModelDiscoveryError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc by_model_id = {model.model_id: model for model in conn.models} for item in discovered: db_model = by_model_id.get(item["model_id"]) diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index 7742e837e..c9ee2779f 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -31,6 +31,10 @@ class VerifyResult: message: str = "" +class ModelDiscoveryError(Exception): + """User-correctable discovery failure for provider configuration issues.""" + + def _auth_headers(conn: Connection) -> dict[str, str]: if not conn.api_key: return {} @@ -120,6 +124,23 @@ async def persist_verification(conn: Connection) -> VerifyResult: return result +def _discovery_error_message(conn: Connection, exc: httpx.HTTPError) -> str: + base_url = _base_url_or_default(conn) + if isinstance(exc, httpx.HTTPStatusError): + status_code = exc.response.status_code + if status_code in (401, 403): + return "Authentication failed while discovering models." + if status_code == 404: + spec = spec_for(conn.provider) + if spec.transport == Transport.OPENAI_COMPATIBLE: + return "OpenAI-compatible servers should expose /v1/models." + return "Model discovery endpoint returned 404." + return f"Model discovery failed with HTTP {status_code}." + if isinstance(exc, httpx.TimeoutException): + return f"Model discovery timed out: {exc}" + return _docker_hint(base_url, exc) + + def _allowlist(conn: Connection) -> set[str]: raw = (conn.extra or {}).get("model_ids") or [] return {str(item).strip() for item in raw if str(item).strip()} @@ -339,20 +360,23 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]: allowlist = _allowlist(conn) spec = spec_for(conn.provider) - if spec.discovery == "ollama": - results = await _ollama_tags_then_show(conn) - elif spec.discovery == "openrouter": - results = await _openrouter_models(conn) - elif spec.discovery == "anthropic_models": - results = await _discover_anthropic_models(conn) - elif spec.discovery == "openai_models": - results = await _discover_openai_shaped_models(conn, conn.base_url) - elif spec.discovery == "bedrock_models": - results = await _discover_bedrock_models(conn) - elif spec.discovery == "static": - results = _litellm_static_models(conn) - else: - results = [] + try: + if spec.discovery == "ollama": + results = await _ollama_tags_then_show(conn) + elif spec.discovery == "openrouter": + results = await _openrouter_models(conn) + elif spec.discovery == "anthropic_models": + results = await _discover_anthropic_models(conn) + elif spec.discovery == "openai_models": + results = await _discover_openai_shaped_models(conn, conn.base_url) + elif spec.discovery == "bedrock_models": + results = await _discover_bedrock_models(conn) + elif spec.discovery == "static": + results = _litellm_static_models(conn) + else: + results = [] + except httpx.HTTPError as exc: + raise ModelDiscoveryError(_discovery_error_message(conn, exc)) from exc if allowlist: results = [item for item in results if item["model_id"] in allowlist] @@ -376,6 +400,7 @@ async def test_model(conn: Connection, model: Model) -> VerifyResult: __all__ = [ + "ModelDiscoveryError", "VerifyResult", "derive_capabilities", "discover_models", diff --git a/surfsense_web/components/settings/model-connections/models-selection-panel.tsx b/surfsense_web/components/settings/model-connections/models-selection-panel.tsx index 01ff0d1e7..573049f6c 100644 --- a/surfsense_web/components/settings/model-connections/models-selection-panel.tsx +++ b/surfsense_web/components/settings/model-connections/models-selection-panel.tsx @@ -1,4 +1,4 @@ -import { RefreshCcw } from "lucide-react"; +import { RefreshCw } from "lucide-react"; import { useState } from "react"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; @@ -32,7 +32,7 @@ interface ModelsSelectionPanelProps { export function ModelsSelectionPanel({ models, description = "Select models to make available for this provider.", - emptyMessage = "No models yet. Use the refresh button to discover models or add one manually.", + emptyMessage = "No models available.", manualInputPlaceholder = "Add a model ID manually", refreshLabel = "Refresh models", isRefreshing = false, @@ -86,14 +86,14 @@ export function ModelsSelectionPanel({ {onRefresh ? ( ) : null}
@@ -113,7 +113,6 @@ export function ModelsSelectionPanel({ placeholder={manualInputPlaceholder} />
) : null} -
+
{models.length === 0 ? (
{emptyMessage} From 9f6210ad089788304c1a2bbb1068e505cefe18c3 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 00:12:04 +0530 Subject: [PATCH 096/212] feat(model-connections): add test preview functionality for model connections --- .../app/routes/model_connections_routes.py | 46 +++++++++- surfsense_backend/app/schemas/__init__.py | 20 ++++- .../app/schemas/model_connections.py | 4 + .../app/services/model_connection_service.py | 87 +++++++++++++++++-- .../app/services/model_resolver.py | 2 + .../app/services/provider_registry.py | 27 ++++-- .../model-connections-mutation.atoms.ts | 15 ++++ .../settings/model-connections-settings.tsx | 82 +++++++++-------- .../connection-settings-dialog.tsx | 63 +++++++++----- .../provider-connect-dialog.tsx | 4 +- .../types/model-connections.types.ts | 5 ++ .../lib/apis/model-connections-api.service.ts | 16 ++++ 12 files changed, 294 insertions(+), 77 deletions(-) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 474d376d3..76e4a3dfb 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -26,11 +26,13 @@ from app.schemas import ( ModelRead, ModelRolesRead, ModelRolesUpdate, - ModelSelection, ModelsBulkUpdate, + ModelSelection, + ModelTestPreview, ModelUpdate, VerifyConnectionResponse, ) +from app.services.model_capabilities import has_capability from app.services.model_connection_service import ( ModelDiscoveryError, derive_capabilities, @@ -38,7 +40,6 @@ from app.services.model_connection_service import ( persist_verification, test_model, ) -from app.services.model_capabilities import has_capability from app.services.provider_registry import REGISTRY from app.users import current_active_user from app.utils.rbac import check_permission @@ -321,6 +322,47 @@ async def preview_connection_models( return [_preview_model_read(item) for item in discovered] +@router.post("/model-connections/test-preview", response_model=VerifyConnectionResponse) +async def test_preview_connection_model( + data: ModelTestPreview, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None: + await check_permission( + session, + user, + data.search_space_id, + Permission.LLM_CONFIGS_CREATE.value, + "You don't have permission to create model connections in this search space", + ) + + model_id = data.model_id.strip() + if not model_id: + raise HTTPException(status_code=400, detail="model_id is required") + + draft = Connection( + provider=data.provider, + base_url=data.base_url, + api_key=data.api_key, + extra=data.extra or {}, + scope=data.scope, + enabled=data.enabled, + search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None, + user_id=user.id, + ) + model = Model( + connection_id=0, + model_id=model_id, + source=ModelSource.MANUAL, + enabled=True, + capabilities_override={}, + catalog={}, + ) + result = await test_model(draft, model) + return VerifyConnectionResponse(status=result.status, ok=result.ok, message=result.message) + + @router.put("/model-connections/{connection_id}", response_model=ConnectionRead) async def update_connection( connection_id: int, diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index efa448dcd..3c4fdfa83 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -54,8 +54,9 @@ from .model_connections import ( ModelRead, ModelRolesRead, ModelRolesUpdate, - ModelSelection, ModelsBulkUpdate, + ModelSelection, + ModelTestPreview, ModelUpdate, VerifyConnectionResponse, ) @@ -149,7 +150,7 @@ from .vision_llm import ( VisionLLMConfigUpdate, ) -__all__ = [ +__all__ = [ # Folder schemas "BulkDocumentMove", # Chat schemas (assistant-ui integration) @@ -159,6 +160,10 @@ __all__ = [ "ChunkCreate", "ChunkRead", "ChunkUpdate", + # Model connection schemas + "ConnectionCreate", + "ConnectionRead", + "ConnectionUpdate", "CreateCreditCheckoutSessionRequest", "CreateCreditCheckoutSessionResponse", "CreditPurchaseHistoryResponse", @@ -232,6 +237,16 @@ __all__ = [ "MembershipRead", "MembershipReadWithUser", "MembershipUpdate", + "ModelCreate", + "ModelPreviewRead", + "ModelProviderRead", + "ModelRead", + "ModelRolesRead", + "ModelRolesUpdate", + "ModelSelection", + "ModelTestPreview", + "ModelUpdate", + "ModelsBulkUpdate", "NewChatMessageAppend", "NewChatMessageCreate", "NewChatMessageRead", @@ -282,6 +297,7 @@ __all__ = [ "UserRead", "UserSearchSpaceAccess", "UserUpdate", + "VerifyConnectionResponse", # Video Presentation schemas "VideoPresentationBase", "VideoPresentationCreate", diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py index 896532d6f..67d94f821 100644 --- a/surfsense_backend/app/schemas/model_connections.py +++ b/surfsense_backend/app/schemas/model_connections.py @@ -85,6 +85,10 @@ class ConnectionCreate(BaseModel): models: list[ModelSelection] = Field(default_factory=list) +class ModelTestPreview(ConnectionCreate): + model_id: str = Field(..., max_length=255) + + class ConnectionUpdate(BaseModel): provider: str | None = Field(None, max_length=100) base_url: str | None = Field(None, max_length=500) diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index c9ee2779f..fbfdd437f 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -15,7 +15,7 @@ import litellm from app.db import Connection, Model, ModelSource from app.services.model_resolver import ensure_v1, to_litellm from app.services.openrouter_model_normalizer import normalize_openrouter_models -from app.services.provider_registry import Transport, spec_for +from app.services.provider_registry import Transport, provider_label, spec_for logger = logging.getLogger(__name__) @@ -77,6 +77,68 @@ def _docker_hint(url: str | None, exc_or_status: Any) -> str: return raw +def _model_test_error(conn: Connection, model_id: str, exc: Exception) -> VerifyResult: + provider_name = provider_label(conn.provider) + raw = str(exc) + normalized = raw.lower() + exc_name = exc.__class__.__name__.lower() + status_code = getattr(exc, "status_code", None) + + logger.info( + "Model test failed for provider=%s model=%s: %s", + conn.provider, + model_id, + raw, + ) + + if status_code in (401, 403) or "authentication" in exc_name or "401" in normalized: + return VerifyResult( + "AUTH_FAILED", + False, + f"Authentication failed. Check your {provider_name} credentials and try again.", + ) + + if status_code == 404 or "notfound" in exc_name or "not found" in normalized: + if conn.provider == "azure": + message = ( + "Azure OpenAI deployment was not found. Check the deployment name, " + "API version, and endpoint." + ) + else: + message = f"Model '{model_id}' was not found on {provider_name}." + return VerifyResult("NOT_FOUND", False, message) + + if status_code == 429 or "ratelimit" in exc_name or "rate limit" in normalized: + return VerifyResult( + "RATE_LIMITED", + False, + f"{provider_name} rate limited the model test. Try again later.", + ) + + if "timeout" in exc_name or "timed out" in normalized: + return VerifyResult( + "TIMEOUT", + False, + f"{provider_name} did not respond in time. Check the endpoint and try again.", + ) + + if "connection" in exc_name or "connect" in normalized: + return VerifyResult( + "UNREACHABLE", + False, + _docker_hint( + _base_url_or_default(conn), + f"Could not reach {provider_name}. Check the endpoint and try again.", + ), + ) + + return VerifyResult( + "UNREACHABLE", + False, + f"Could not test model '{model_id}' on {provider_name}. Check the credentials, endpoint, and model name.", + ) + + async def verify_connection(conn: Connection) -> VerifyResult: spec = spec_for(conn.provider) base_url = _base_url_or_default(conn) @@ -321,15 +383,24 @@ async def _discover_bedrock_models(conn: Connection) -> list[dict[str, Any]]: return [] def list_models() -> list[dict[str, Any]]: + import os + import boto3 - client_kwargs: dict[str, str] = {"region_name": region_name} - if params.get("aws_access_key_id"): - client_kwargs["aws_access_key_id"] = params["aws_access_key_id"] - if params.get("aws_secret_access_key"): - client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"] + if bearer_token := params.get("aws_bearer_token_bedrock"): + try: + os.environ["AWS_BEARER_TOKEN_BEDROCK"] = bearer_token + client = boto3.client("bedrock", region_name=region_name) + finally: + os.environ.pop("AWS_BEARER_TOKEN_BEDROCK", None) + else: + client_kwargs: dict[str, str] = {"region_name": region_name} + if params.get("aws_access_key_id"): + client_kwargs["aws_access_key_id"] = params["aws_access_key_id"] + if params.get("aws_secret_access_key"): + client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"] + client = boto3.client("bedrock", **client_kwargs) - client = boto3.client("bedrock", **client_kwargs) response = client.list_foundation_models() results: list[dict[str, Any]] = [] for item in response.get("modelSummaries", []): @@ -393,7 +464,7 @@ async def test_model(conn: Connection, model: Model) -> VerifyResult: **kwargs, ) except Exception as exc: - return VerifyResult("UNREACHABLE", False, str(exc)) + return _model_test_error(conn, model.model_id, exc) model.supports_chat = True return VerifyResult("OK", True, "Model test succeeded.") diff --git a/surfsense_backend/app/services/model_resolver.py b/surfsense_backend/app/services/model_resolver.py index ae6fd2877..599762824 100644 --- a/surfsense_backend/app/services/model_resolver.py +++ b/surfsense_backend/app/services/model_resolver.py @@ -55,6 +55,8 @@ def to_litellm( kwargs["api_version"] = api_version kwargs.update(extra.get("litellm_params", {})) kwargs.update(extra.get("kwargs", {})) + if provider == "bedrock" and (bearer_token := kwargs.pop("aws_bearer_token_bedrock", None)): + kwargs["api_key"] = bearer_token return model_string, kwargs diff --git a/surfsense_backend/app/services/provider_registry.py b/surfsense_backend/app/services/provider_registry.py index 98bfb63c1..2a58a3468 100644 --- a/surfsense_backend/app/services/provider_registry.py +++ b/surfsense_backend/app/services/provider_registry.py @@ -38,21 +38,24 @@ class ProviderSpec: default_base_url: str | None base_url_required: bool auth_style: AuthStyle + display_name: str | None = None REGISTRY: dict[str, ProviderSpec] = { "openai": ProviderSpec( - Transport.NATIVE, "openai", "openai_models", None, False, "bearer" + Transport.NATIVE, "openai", "openai_models", None, False, "bearer", "OpenAI" ), "anthropic": ProviderSpec( - Transport.NATIVE, "anthropic", "anthropic_models", None, False, "x-api-key" + Transport.NATIVE, "anthropic", "anthropic_models", None, False, "x-api-key", "Anthropic" + ), + "azure": ProviderSpec( + Transport.NATIVE, "azure", "static", None, True, "native", "Azure OpenAI" ), - "azure": ProviderSpec(Transport.NATIVE, "azure", "static", None, True, "native"), "vertex_ai": ProviderSpec( - Transport.NATIVE, "vertex_ai", "static", None, False, "native" + Transport.NATIVE, "vertex_ai", "static", None, False, "native", "Vertex AI" ), "bedrock": ProviderSpec( - Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native" + Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native", "Amazon Bedrock" ), "openrouter": ProviderSpec( Transport.OPENAI_COMPATIBLE, @@ -61,6 +64,7 @@ REGISTRY: dict[str, ProviderSpec] = { "https://openrouter.ai/api/v1", False, "bearer", + "OpenRouter", ), "openai_compatible": ProviderSpec( Transport.OPENAI_COMPATIBLE, @@ -69,6 +73,7 @@ REGISTRY: dict[str, ProviderSpec] = { None, True, "bearer", + "OpenAI-compatible provider", ), "lm_studio": ProviderSpec( Transport.OPENAI_COMPATIBLE, @@ -77,6 +82,7 @@ REGISTRY: dict[str, ProviderSpec] = { "http://localhost:1234/v1", True, "bearer", + "LM Studio", ), "ollama_chat": ProviderSpec( Transport.OLLAMA, @@ -85,6 +91,7 @@ REGISTRY: dict[str, ProviderSpec] = { "http://localhost:11434", True, "none", + "Ollama", ), } @@ -96,4 +103,12 @@ def spec_for(provider: str | None) -> ProviderSpec: ) -__all__ = ["REGISTRY", "ProviderSpec", "Transport", "spec_for"] +def provider_label(provider: str | None) -> str: + provider_key = (provider or "").strip() + spec = spec_for(provider_key) + if spec.display_name: + return spec.display_name + return provider_key.replace("_", " ").title() if provider_key else "Provider" + + +__all__ = ["REGISTRY", "ProviderSpec", "Transport", "provider_label", "spec_for"] diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts index ea91c6483..00f8fa9ad 100644 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -7,6 +7,7 @@ import type { ModelPreviewRead, ModelRead, ModelRoles, + ModelTestPreviewRequest, ModelsBulkUpdateRequest, ModelUpdateRequest, VerifyConnectionResponse, @@ -114,6 +115,20 @@ export const previewConnectionModelsMutationAtom = atomWithMutation(() => { }; }); +export const testPreviewModelMutationAtom = atomWithMutation(() => { + return { + mutationKey: ["model-connections", "test-preview"], + mutationFn: (request: ModelTestPreviewRequest) => + modelConnectionsApiService.testPreviewModel(request), + onSuccess: (result: VerifyConnectionResponse) => { + if (!result.ok) { + toast.error(result.message || "Model test failed"); + } + }, + onError: (error: Error) => toast.error(error.message || "Failed to test model"), + }; +}); + export const addManualModelMutationAtom = atomWithMutation((get) => { const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); return { diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 6c3d1a411..cf00ac6c9 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -1,12 +1,14 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { CheckCircle2, Trash2, XCircle } from "lucide-react"; +import { Trash2 } from "lucide-react"; import { useState } from "react"; +import { toast } from "sonner"; import { createModelConnectionMutationAtom, deleteModelConnectionMutationAtom, previewConnectionModelsMutationAtom, + testPreviewModelMutationAtom, updateModelRolesMutationAtom, } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { @@ -53,24 +55,6 @@ import { providerIcon, } from "./model-connections/provider-metadata"; -function StatusBadge({ connection }: { connection: ConnectionRead }) { - if (connection.last_status === "OK") { - return ( - - Healthy - - ); - } - if (connection.last_status) { - return ( - - {connection.last_status} - - ); - } - return Not tested; -} - function flattenModels(connections: ConnectionRead[]) { return connections.flatMap((connection) => connection.models.map((model) => ({ @@ -110,7 +94,6 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
- @@ -156,6 +139,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num const [{ data: roles }] = useAtom(modelRolesAtom); const createConnection = useAtomValue(createModelConnectionMutationAtom); const previewModels = useAtomValue(previewConnectionModelsMutationAtom); + const testPreviewModel = useAtomValue(testPreviewModelMutationAtom); const updateRoles = useAtomValue(updateModelRolesMutationAtom); const [isAddProviderOpen, setIsAddProviderOpen] = useState(false); @@ -220,9 +204,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num }); } - // Each provider connect form builds its own credential payload; the backend - // resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM. - function handleCreate(draft: ConnectionDraft) { + function connectionModelsForDraft(draft: ConnectionDraft) { const models = [...connectModels]; if (draft.seedModelId && !models.some((model) => model.model_id === draft.seedModelId)) { models.push({ @@ -233,22 +215,46 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num metadata: {}, }); } + return models; + } - createConnection.mutate( + function representativeTestModel(models: ModelSelection[]) { + const enabledModels = models.filter((model) => model.enabled); + return enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0]; + } + + // Each provider connect form builds its own credential payload; the backend + // resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM. + function handleCreate(draft: ConnectionDraft) { + const models = connectionModelsForDraft(draft); + const testModel = representativeTestModel(models); + if (!testModel) { + toast.error("Select at least one model before connecting"); + return; + } + + const request = { + provider, + base_url: draft.base_url, + api_key: draft.api_key, + scope: "SEARCH_SPACE" as const, + search_space_id: searchSpaceId, + extra: draft.extra, + enabled: true, + models, + }; + + testPreviewModel.mutate( + { ...request, model_id: testModel.model_id }, { - provider, - base_url: draft.base_url, - api_key: draft.api_key, - scope: "SEARCH_SPACE", - search_space_id: searchSpaceId, - extra: draft.extra, - enabled: true, - models, - }, - { - onSuccess: () => { - setIsAddProviderOpen(false); - resetConnectState(); + onSuccess: (result) => { + if (!result.ok) return; + createConnection.mutate(request, { + onSuccess: () => { + setIsAddProviderOpen(false); + resetConnectState(); + }, + }); }, } ); @@ -380,7 +386,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num onOpenChange={handleConnectOpenChange} provider={provider} selectedProvider={selectedProvider} - isPending={createConnection.isPending} + isPending={createConnection.isPending || testPreviewModel.isPending} onSubmit={handleCreate} previewModels={connectModels} isPreviewingModels={previewModels.isPending} diff --git a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx index d0f8e6c16..badddb8d7 100644 --- a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx +++ b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx @@ -5,9 +5,9 @@ import { addManualModelMutationAtom, bulkUpdateModelsMutationAtom, discoverConnectionModelsMutationAtom, + testPreviewModelMutationAtom, updateModelConnectionMutationAtom, updateModelMutationAtom, - verifyModelConnectionMutationAtom, } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { Button } from "@/components/ui/button"; import { @@ -26,7 +26,7 @@ import type { ConnectionRead, ConnectionUpdateRequest, } from "@/contracts/types/model-connections.types"; -import type { SelectableModel } from "./model-utils"; +import { capability, type SelectableModel } from "./model-utils"; import { ModelsSelectionPanel } from "./models-selection-panel"; import { providerIcon } from "./provider-metadata"; @@ -39,8 +39,8 @@ export function ConnectionSettingsDialog({ connection, providerLabel, }: ConnectionSettingsDialogProps) { - const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom); const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); + const testPreviewModel = useAtomValue(testPreviewModelMutationAtom); const updateConnection = useAtomValue(updateModelConnectionMutationAtom); const addManualModel = useAtomValue(addManualModelMutationAtom); const updateModel = useAtomValue(updateModelMutationAtom); @@ -81,11 +81,45 @@ export function ConnectionSettingsDialog({ if (apiKeyDraft.trim() !== (connection.api_key ?? "")) { data.api_key = apiKeyDraft.trim() || null; } + const apiKeyForTest = Object.hasOwn(data, "api_key") + ? (data.api_key ?? null) + : (connection.api_key ?? null); - updateConnection.mutate( - { id: connection.id, data }, + const enabledModels = connection.models.filter((model) => model.enabled); + const testModel = + enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0]; + if (!testModel) { + updateConnection.mutate( + { id: connection.id, data }, + { + onSuccess: () => setApiKeyDraft(""), + } + ); + return; + } + + testPreviewModel.mutate( { - onSuccess: () => setApiKeyDraft(""), + provider: connection.provider, + base_url: data.base_url, + api_key: apiKeyForTest, + scope: "SEARCH_SPACE", + search_space_id: connection.search_space_id, + extra: connection.extra ?? {}, + enabled: connection.enabled, + models: [], + model_id: testModel.model_id, + }, + { + onSuccess: (result) => { + if (!result.ok) return; + updateConnection.mutate( + { id: connection.id, data }, + { + onSuccess: () => setApiKeyDraft(""), + } + ); + }, } ); } @@ -219,26 +253,15 @@ export function ConnectionSettingsDialog({ onBulkToggle={handleBulkToggle} /> - {connection.last_status && connection.last_status !== "OK" ? ( -

- {connection.last_error || "Could not list models."} Chat may still work; add model - IDs manually if discovery is unavailable. -

- ) : null}
- diff --git a/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx b/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx index 315b9d3fa..51263d5f5 100644 --- a/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx +++ b/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx @@ -94,6 +94,8 @@ export function ProviderConnectDialog({ })(); const canRefreshModels = !isAzure && !isVertex && (!isBedrock || canSubmit); + const hasEnabledModel = previewModels.some((model) => model.enabled) || Boolean(currentDraft.seedModelId); + const canConnect = canSubmit && hasEnabledModel; return ( @@ -134,7 +136,7 @@ export function ProviderConnectDialog({ onOpenChange(false)} onSubmit={() => onSubmit(currentDraft)} - canSubmit={canSubmit} + canSubmit={canConnect} isPending={isPending} /> diff --git a/surfsense_web/contracts/types/model-connections.types.ts b/surfsense_web/contracts/types/model-connections.types.ts index 134c740b2..16db93868 100644 --- a/surfsense_web/contracts/types/model-connections.types.ts +++ b/surfsense_web/contracts/types/model-connections.types.ts @@ -66,6 +66,10 @@ export const connectionCreateRequest = z.object({ models: z.array(modelSelection).default([]), }); +export const modelTestPreviewRequest = connectionCreateRequest.extend({ + model_id: z.string().min(1), +}); + export const connectionUpdateRequest = z.object({ provider: z.string().nullable().optional(), base_url: z.string().nullable().optional(), @@ -129,6 +133,7 @@ export type ModelPreviewRead = z.infer; export type ModelSelection = z.infer; export type ConnectionRead = z.infer; export type ConnectionCreateRequest = z.infer; +export type ModelTestPreviewRequest = z.infer; export type ConnectionUpdateRequest = z.infer; export type ModelCreateRequest = z.infer; export type ModelUpdateRequest = z.infer; diff --git a/surfsense_web/lib/apis/model-connections-api.service.ts b/surfsense_web/lib/apis/model-connections-api.service.ts index f463a27e7..d875255ad 100644 --- a/surfsense_web/lib/apis/model-connections-api.service.ts +++ b/surfsense_web/lib/apis/model-connections-api.service.ts @@ -11,6 +11,7 @@ import { type ModelProviderRead, type ModelRead, type ModelRoles, + type ModelTestPreviewRequest, type ModelsBulkUpdateRequest, type ModelUpdateRequest, modelCreateRequest, @@ -19,6 +20,7 @@ import { modelProviderListResponse, modelRead, modelRoles, + modelTestPreviewRequest, modelsBulkUpdateRequest, modelUpdateRequest, type VerifyConnectionResponse, @@ -92,6 +94,20 @@ class ModelConnectionsApiService { ); }; + testPreviewModel = async (request: ModelTestPreviewRequest): Promise => { + const parsed = modelTestPreviewRequest.safeParse(request); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); + } + return baseApiService.post( + `/api/v1/model-connections/test-preview`, + verifyConnectionResponse, + { + body: parsed.data, + } + ); + }; + addManualModel = async ( connectionId: number, request: ModelCreateRequest From e77b0c5d4eaf0fd83493fbe273d619aa23e7f79e Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 00:25:53 +0530 Subject: [PATCH 097/212] feat(icons): add Azure, Claude, and LM Studio icons; update Bedrock icon and provider metadata --- .../components/icons/providers/azure.svg | 1 + .../components/icons/providers/bedrock.svg | 2 +- .../components/icons/providers/claude.svg | 1 + .../components/icons/providers/index.ts | 3 +++ .../components/icons/providers/lm-studio.svg | 21 +++++++++++++++++++ .../components/icons/providers/vertexai.svg | 2 +- .../model-connections/provider-metadata.tsx | 6 +++--- surfsense_web/lib/provider-icons.tsx | 14 +++++++++---- 8 files changed, 41 insertions(+), 9 deletions(-) create mode 100644 surfsense_web/components/icons/providers/azure.svg create mode 100644 surfsense_web/components/icons/providers/claude.svg create mode 100644 surfsense_web/components/icons/providers/lm-studio.svg diff --git a/surfsense_web/components/icons/providers/azure.svg b/surfsense_web/components/icons/providers/azure.svg new file mode 100644 index 000000000..ba80f55ca --- /dev/null +++ b/surfsense_web/components/icons/providers/azure.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/surfsense_web/components/icons/providers/bedrock.svg b/surfsense_web/components/icons/providers/bedrock.svg index 195aa6594..cde500c0d 100644 --- a/surfsense_web/components/icons/providers/bedrock.svg +++ b/surfsense_web/components/icons/providers/bedrock.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/surfsense_web/components/icons/providers/claude.svg b/surfsense_web/components/icons/providers/claude.svg new file mode 100644 index 000000000..8d732d5b0 --- /dev/null +++ b/surfsense_web/components/icons/providers/claude.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/surfsense_web/components/icons/providers/index.ts b/surfsense_web/components/icons/providers/index.ts index aefa2a053..8275595b1 100644 --- a/surfsense_web/components/icons/providers/index.ts +++ b/surfsense_web/components/icons/providers/index.ts @@ -1,8 +1,10 @@ export { default as Ai21Icon } from "./ai21.svg"; export { default as AnthropicIcon } from "./anthropic.svg"; export { default as AnyscaleIcon } from "./anyscale.svg"; +export { default as AzureIcon } from "./azure.svg"; export { default as BedrockIcon } from "./bedrock.svg"; export { default as CerebrasIcon } from "./cerebras.svg"; +export { default as ClaudeIcon } from "./claude.svg"; export { default as CohereIcon } from "./cohere.svg"; export { default as CometApiIcon } from "./cometapi.svg"; export { default as DatabricksIcon } from "./dbrx.svg"; @@ -15,6 +17,7 @@ export { default as GroqIcon } from "./groq.svg"; export { default as HuggingFaceIcon } from "./huggingface.svg"; export { default as MiniMaxIcon } from "./minimax.svg"; export { default as MistralIcon } from "./mistral.svg"; +export { default as LmStudioIcon } from "./lm-studio.svg"; export { default as MoonshotIcon } from "./moonshot.svg"; export { default as NscaleIcon } from "./nscale.svg"; export { default as OllamaIcon } from "./ollama.svg"; diff --git a/surfsense_web/components/icons/providers/lm-studio.svg b/surfsense_web/components/icons/providers/lm-studio.svg new file mode 100644 index 000000000..b6ae7db3e --- /dev/null +++ b/surfsense_web/components/icons/providers/lm-studio.svg @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/surfsense_web/components/icons/providers/vertexai.svg b/surfsense_web/components/icons/providers/vertexai.svg index 45adce83b..e46a3ca0f 100644 --- a/surfsense_web/components/icons/providers/vertexai.svg +++ b/surfsense_web/components/icons/providers/vertexai.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/surfsense_web/components/settings/model-connections/provider-metadata.tsx b/surfsense_web/components/settings/model-connections/provider-metadata.tsx index 73e873393..8b8a877b9 100644 --- a/surfsense_web/components/settings/model-connections/provider-metadata.tsx +++ b/surfsense_web/components/settings/model-connections/provider-metadata.tsx @@ -19,12 +19,12 @@ export const PROVIDER_DISPLAY: Record< anthropic: { name: "Claude", subtitle: "Anthropic", - iconKey: "anthropic", + iconKey: "claude", defaultBaseUrl: "https://api.anthropic.com/v1", }, - azure: { name: "Azure OpenAI", subtitle: "Microsoft Azure", iconKey: "azure_openai" }, + azure: { name: "Azure OpenAI", subtitle: "Microsoft Azure", iconKey: "azure" }, bedrock: { name: "Amazon Bedrock", subtitle: "AWS", iconKey: "bedrock" }, - lm_studio: { name: "LM Studio", subtitle: "LM Studio", iconKey: "custom" }, + lm_studio: { name: "LM Studio", subtitle: "LM Studio", iconKey: "lm_studio" }, ollama_chat: { name: "Ollama", subtitle: "Ollama", iconKey: "ollama" }, openai: { name: "GPT", diff --git a/surfsense_web/lib/provider-icons.tsx b/surfsense_web/lib/provider-icons.tsx index 3bb310904..4b2a4dfbe 100644 --- a/surfsense_web/lib/provider-icons.tsx +++ b/surfsense_web/lib/provider-icons.tsx @@ -1,10 +1,11 @@ import { Cpu, Shuffle } from "lucide-react"; import { Ai21Icon, - AnthropicIcon, AnyscaleIcon, + AzureIcon, BedrockIcon, CerebrasIcon, + ClaudeIcon, CloudflareIcon, CohereIcon, CometApiIcon, @@ -16,6 +17,7 @@ import { GitHubModelsIcon, GroqIcon, HuggingFaceIcon, + LmStudioIcon, MiniMaxIcon, MistralIcon, MoonshotIcon, @@ -54,12 +56,13 @@ export function getProviderIcon( case "ALIBABA_QWEN": return ; case "ANTHROPIC": - return ; + case "CLAUDE": + return ; case "ANYSCALE": return ; case "AZURE": case "AZURE_OPENAI": - return ; + return ; case "AWS_BEDROCK": case "BEDROCK": return ; @@ -72,7 +75,7 @@ export function getProviderIcon( case "COMETAPI": return ; case "CUSTOM": - return ; + return ; case "DATABRICKS": return ; case "DEEPINFRA": @@ -89,6 +92,8 @@ export function getProviderIcon( return ; case "HUGGINGFACE": return ; + case "LM_STUDIO": + return ; case "MINIMAX": return ; case "MISTRAL": @@ -98,6 +103,7 @@ export function getProviderIcon( case "NSCALE": return ; case "OLLAMA": + case "OLLAMA_CHAT": return ; case "OPENAI": return ; From 7a1bb2acd6fdce526aebe41d99a916ccad75c689 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 00:46:53 +0530 Subject: [PATCH 098/212] feat(model-connections): refactor model roles UI --- .../settings/model-connections-settings.tsx | 132 +++++++++--------- .../connection-settings-dialog.tsx | 7 +- 2 files changed, 75 insertions(+), 64 deletions(-) diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index cf00ac6c9..c61368974 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -1,7 +1,7 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { Trash2 } from "lucide-react"; +import { Dot, Trash2 } from "lucide-react"; import { useState } from "react"; import { toast } from "sonner"; import { @@ -30,7 +30,6 @@ import { } from "@/components/ui/alert-dialog"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; -import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; import { Label } from "@/components/ui/label"; import { Select, @@ -78,7 +77,7 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { return (
-
+
{providerIcon(connection.provider)} @@ -100,10 +99,11 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { @@ -335,7 +335,11 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num {providerIcon(model.provider)} - {modelLabel(model)} · {model.connectionName} + + {modelLabel(model)} + ); @@ -343,7 +347,66 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num return (
-
+
+
+

Model Roles

+

+ Pick which enabled model powers chat, vision, and image generation for this search + space. +

+
+
+
+ + +
+
+ + +
+
+ + +
+
+
+ + + +

Add Provider

@@ -408,63 +471,6 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
) : null}
- - - - Model Roles - - Pick which enabled model powers chat, vision, and image generation for this search - space. - - - -
- - -
-
- - -
-
- - -
-
-
); } diff --git a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx index badddb8d7..d20dbbdc6 100644 --- a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx +++ b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx @@ -157,7 +157,12 @@ export function ConnectionSettingsDialog({ return ( - From 45d27ba87992b6ec9c07a0df3d24357e7f830d2e Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 01:37:12 +0530 Subject: [PATCH 099/212] feat(model-connections): enhance auto mode with auto pinning --- .../versions/160_add_model_connections.py | 39 ++++- surfsense_backend/app/db.py | 6 +- .../app/routes/model_connections_routes.py | 140 +++++++++++++++++- .../settings/model-connections-settings.tsx | 30 +++- 4 files changed, 196 insertions(+), 19 deletions(-) diff --git a/surfsense_backend/alembic/versions/160_add_model_connections.py b/surfsense_backend/alembic/versions/160_add_model_connections.py index 49d6315ca..2c35bd568 100644 --- a/surfsense_backend/alembic/versions/160_add_model_connections.py +++ b/surfsense_backend/alembic/versions/160_add_model_connections.py @@ -61,9 +61,21 @@ def _create_index_if_missing( op.create_index(index_name, table_name, columns, unique=False) -def _add_searchspace_column_if_missing(column_name: str) -> None: +def _add_searchspace_column_if_missing( + column_name: str, + *, + server_default: object | None = None, +) -> None: if not _column_exists("searchspaces", column_name): - op.add_column("searchspaces", sa.Column(column_name, sa.Integer(), nullable=True)) + op.add_column( + "searchspaces", + sa.Column( + column_name, + sa.Integer(), + nullable=True, + server_default=server_default, + ), + ) def _drop_column_if_exists(table_name: str, column_name: str) -> None: @@ -233,9 +245,26 @@ def upgrade() -> None: _create_index_if_missing("ix_models_model_id", "models", ["model_id"]) _create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"]) - _add_searchspace_column_if_missing("chat_model_id") - _add_searchspace_column_if_missing("image_gen_model_id") - _add_searchspace_column_if_missing("vision_model_id") + _add_searchspace_column_if_missing("chat_model_id", server_default=sa.text("0")) + _add_searchspace_column_if_missing("image_gen_model_id", server_default=sa.text("0")) + _add_searchspace_column_if_missing("vision_model_id", server_default=sa.text("0")) + for column_name in ("chat_model_id", "image_gen_model_id", "vision_model_id"): + op.alter_column( + "searchspaces", + column_name, + existing_type=sa.Integer(), + existing_nullable=True, + server_default=sa.text("0"), + ) + op.execute( + """ + UPDATE searchspaces + SET + chat_model_id = COALESCE(chat_model_id, 0), + image_gen_model_id = COALESCE(image_gen_model_id, 0), + vision_model_id = COALESCE(vision_model_id, 0) + """ + ) op.execute("DROP TYPE IF EXISTS connectionprotocol") diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index eeed9932d..728031fa0 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -1853,13 +1853,13 @@ class SearchSpace(BaseModel, TimestampMixin): # - Negative IDs: Global virtual models from global_llm_config.yaml # - Positive IDs: User/search-space models from the models table chat_model_id = Column( - Integer, nullable=True, default=0 + Integer, nullable=True, default=0, server_default="0" ) # For agent/chat operations, defaults to Auto mode image_gen_model_id = Column( - Integer, nullable=True, default=0 + Integer, nullable=True, default=0, server_default="0" ) # For image generation, defaults to Auto mode when eligible vision_model_id = Column( - Integer, nullable=True, default=0 + Integer, nullable=True, default=0, server_default="0" ) # For vision/screenshot analysis, defaults to Auto mode ai_file_sort_enabled = Column( diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 76e4a3dfb..6db9aa9f3 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -131,6 +131,95 @@ def _default_model_for(models: list[Model], capability: str) -> int | None: return None +async def _load_role_model( + session: AsyncSession, + search_space_id: int, + model_id: int, +) -> Model | dict | None: + if model_id < 0: + return next( + (model for model in config.GLOBAL_MODELS if model.get("id") == model_id), + None, + ) + + result = await session.execute( + select(Model) + .options(selectinload(Model.connection)) + .where(Model.id == model_id) + ) + model = result.scalars().first() + if model is None or model.connection.search_space_id != search_space_id: + return None + return model + + +def _role_model_enabled(model: Model | dict) -> bool: + if isinstance(model, dict): + return bool(model.get("enabled", True)) + return bool(model.enabled and model.connection.enabled) + + +async def _validate_role_model_id( + session: AsyncSession, + *, + search_space_id: int, + model_id: int | None, + capability: str, +) -> int: + if model_id is None or model_id == 0: + return 0 + + model = await _load_role_model(session, search_space_id, model_id) + if model and _role_model_enabled(model) and has_capability(model, capability): + return model_id + + raise HTTPException( + status_code=400, + detail=f"Selected model is not available for {capability}", + ) + + +async def _resolve_role_model_id( + session: AsyncSession, + *, + search_space_id: int, + model_id: int | None, + capability: str, +) -> int: + try: + return await _validate_role_model_id( + session, + search_space_id=search_space_id, + model_id=model_id, + capability=capability, + ) + except HTTPException: + return 0 + + +async def _clear_invalid_roles(session: AsyncSession, search_space_id: int) -> SearchSpace: + search_space = await _get_search_space(session, search_space_id) + search_space.chat_model_id = await _resolve_role_model_id( + session, + search_space_id=search_space_id, + model_id=search_space.chat_model_id, + capability="chat", + ) + search_space.vision_model_id = await _resolve_role_model_id( + session, + search_space_id=search_space_id, + model_id=search_space.vision_model_id, + capability="vision", + ) + search_space.image_gen_model_id = await _resolve_role_model_id( + session, + search_space_id=search_space_id, + model_id=search_space.image_gen_model_id, + capability="image_gen", + ) + return search_space + + async def _default_unset_roles( session: AsyncSession, conn: Connection, @@ -372,9 +461,13 @@ async def update_connection( ): conn = await _load_connection(session, connection_id) await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value) + search_space_id = conn.search_space_id for key, value in data.model_dump(exclude_unset=True).items(): setattr(conn, key, value) await session.commit() + if search_space_id is not None: + await _clear_invalid_roles(session, search_space_id) + await session.commit() conn = await _load_connection(session, connection_id) return _connection_read(conn, list(conn.models)) @@ -387,8 +480,12 @@ async def delete_connection( ): conn = await _load_connection(session, connection_id) await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_DELETE.value) + search_space_id = conn.search_space_id await session.delete(conn) await session.commit() + if search_space_id is not None: + await _clear_invalid_roles(session, search_space_id) + await session.commit() return {"status": "deleted"} @@ -439,6 +536,8 @@ async def discover_connection_models( await session.commit() conn = await _load_connection(session, connection_id) await _default_unset_roles(session, conn, list(conn.models)) + if conn.search_space_id is not None: + await _clear_invalid_roles(session, conn.search_space_id) await session.commit() conn = await _load_connection(session, connection_id) return [_model_read(model) for model in conn.models] @@ -476,7 +575,10 @@ async def add_manual_model( await session.refresh(model) conn = await _load_connection(session, connection_id) await _default_unset_roles(session, conn, list(conn.models)) + if conn.search_space_id is not None: + await _clear_invalid_roles(session, conn.search_space_id) await session.commit() + await session.refresh(model) return _model_read(model) @@ -489,6 +591,7 @@ async def bulk_update_models( ): conn = await _load_connection(session, connection_id) await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value) + search_space_id = conn.search_space_id model_ids = set(data.model_ids) await session.execute( @@ -498,6 +601,10 @@ async def bulk_update_models( ) await session.commit() session.expire_all() + if search_space_id is not None: + await _clear_invalid_roles(session, search_space_id) + await session.commit() + session.expire_all() result = await session.execute( select(Model) @@ -521,11 +628,16 @@ async def update_model( if not model: raise HTTPException(status_code=404, detail="Model not found") await _assert_connection_access(session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value) + search_space_id = model.connection.search_space_id update = data.model_dump(exclude_unset=True) for key, value in update.items(): setattr(model, key, value) await session.commit() await session.refresh(model) + if search_space_id is not None: + await _clear_invalid_roles(session, search_space_id) + await session.commit() + await session.refresh(model) return _model_read(model) @@ -560,7 +672,9 @@ async def get_model_roles( Permission.LLM_CONFIGS_CREATE.value, "You don't have permission to view model roles in this search space", ) - search_space = await _get_search_space(session, search_space_id) + search_space = await _clear_invalid_roles(session, search_space_id) + await session.commit() + await session.refresh(search_space) return ModelRolesRead( chat_model_id=search_space.chat_model_id, vision_model_id=search_space.vision_model_id, @@ -583,8 +697,28 @@ async def update_model_roles( "You don't have permission to update model roles in this search space", ) search_space = await _get_search_space(session, search_space_id) - for key, value in data.model_dump(exclude_unset=True).items(): - setattr(search_space, key, value) + updates = data.model_dump(exclude_unset=True) + if "chat_model_id" in updates: + search_space.chat_model_id = await _validate_role_model_id( + session, + search_space_id=search_space_id, + model_id=updates["chat_model_id"], + capability="chat", + ) + if "vision_model_id" in updates: + search_space.vision_model_id = await _validate_role_model_id( + session, + search_space_id=search_space_id, + model_id=updates["vision_model_id"], + capability="vision", + ) + if "image_gen_model_id" in updates: + search_space.image_gen_model_id = await _validate_role_model_id( + session, + search_space_id=search_space_id, + model_id=updates["image_gen_model_id"], + capability="image_gen", + ) await session.commit() await session.refresh(search_space) return ModelRolesRead( diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index c61368974..1f5c166b3 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -65,6 +65,11 @@ function flattenModels(connections: ConnectionRead[]) { ); } +function roleSelectValue(modelId: number | null | undefined, models: Array<{ id: number }>) { + if (!modelId) return "0"; + return models.some((model) => model.id === modelId) ? String(modelId) : "0"; +} + function ConnectionCard({ connection }: { connection: ConnectionRead }) { const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom); @@ -349,8 +354,8 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
-

Model Roles

-

+

Model Roles

+

Pick which enabled model powers chat, vision, and image generation for this search space.

@@ -358,8 +363,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
+

+ Primary model for chat responses and agent tasks. You can also change it from the + chat. +

updateRoles.mutate({ vision_model_id: Number(value) })} > @@ -388,8 +401,9 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
+

Used when generating images in chat.

setSearch(event.target.value)} - placeholder="Search chat models..." - className="pl-9" + placeholder="Search chat models" + className="h-8 border-0 bg-transparent pl-6 text-sm shadow-none" />
-
+
@@ -217,17 +226,22 @@ export function ModelSelector({ type="button" variant="ghost" size="sm" - className={cn("h-8 gap-2 rounded-full px-3 text-muted-foreground", className)} + className={cn( + "h-8 min-w-0 gap-2 rounded-md px-3 text-muted-foreground transition-colors", + "hover:bg-foreground/10 hover:text-foreground", + "data-[state=open]:bg-foreground/10 data-[state=open]:text-foreground", + className + )} > {selected ? ( - getProviderIcon(selected.provider, { className: "size-4" }) + getProviderIcon(selected.provider, { className: "size-4 shrink-0" }) ) : ( - + )} - + {selected ? modelName(selected) : "Auto"} - + ); @@ -235,7 +249,8 @@ export function ModelSelector({ return ( {trigger} - + + Select Chat Model @@ -248,7 +263,7 @@ export function ModelSelector({ return ( {trigger} - + {content} From 7493ba93241a2d526f3bd8e42954b01c3879c189 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 03:34:13 +0530 Subject: [PATCH 104/212] refactor(icons): replace Clock3 icon with AlarmClock for automations --- .../automations/components/automations-empty-state.tsx | 4 ++-- .../automations/components/automations-table.tsx | 4 ++-- .../components/layout/providers/LayoutDataProvider.tsx | 4 ++-- surfsense_web/components/new-chat/chat-example-prompts.tsx | 4 ++-- .../components/tool-ui/automation/create-automation.tsx | 6 +++--- surfsense_web/contracts/enums/toolIcons.tsx | 4 ++-- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx index 70d9990f8..1ee71c636 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx @@ -1,5 +1,5 @@ "use client"; -import { Clock3 } from "lucide-react"; +import { AlarmClock } from "lucide-react"; import Link from "next/link"; import { Button } from "@/components/ui/button"; @@ -18,7 +18,7 @@ export function AutomationsEmptyState({ searchSpaceId, canCreate }: AutomationsE return (
- +

No automations yet

diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx index 727636b43..9ca510d44 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx @@ -1,5 +1,5 @@ "use client"; -import { CalendarDays, Clock3, Info } from "lucide-react"; +import { CalendarDays, AlarmClock, Info } from "lucide-react"; import { Table, TableBody, TableHead, TableHeader, TableRow } from "@/components/ui/table"; import type { AutomationSummary } from "@/contracts/types/automation.types"; import { AutomationRow } from "./automation-row"; @@ -31,7 +31,7 @@ export function AutomationsTable({ - + Name diff --git a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx index 5c62f6a7d..d2754594a 100644 --- a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx +++ b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx @@ -2,7 +2,7 @@ import { useQuery } from "@tanstack/react-query"; import { useAtom, useAtomValue, useSetAtom } from "jotai"; -import { AlertTriangle, Clock3, Inbox, LibraryBig } from "lucide-react"; +import { AlertTriangle, AlarmClock, Inbox, LibraryBig } from "lucide-react"; import { useParams, usePathname, useRouter } from "next/navigation"; import { useTranslations } from "next-intl"; import { useTheme } from "next-themes"; @@ -342,7 +342,7 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid { title: "Automations", url: `/dashboard/${searchSpaceId}/automations`, - icon: Clock3, + icon: AlarmClock, isActive: isAutomationsActive, }, isMobile diff --git a/surfsense_web/components/new-chat/chat-example-prompts.tsx b/surfsense_web/components/new-chat/chat-example-prompts.tsx index 4fdc32a92..28fa79c9d 100644 --- a/surfsense_web/components/new-chat/chat-example-prompts.tsx +++ b/surfsense_web/components/new-chat/chat-example-prompts.tsx @@ -1,7 +1,7 @@ "use client"; import { - Clock3, + AlarmClock, FilePlus2, Search, Settings2, @@ -22,7 +22,7 @@ interface ChatExamplePromptsProps { const CATEGORY_ICONS: Record = { search: Search, create: FilePlus2, - automate: Clock3, + automate: AlarmClock, tools: Settings2, }; diff --git a/surfsense_web/components/tool-ui/automation/create-automation.tsx b/surfsense_web/components/tool-ui/automation/create-automation.tsx index 644ccd822..5cffdeb6c 100644 --- a/surfsense_web/components/tool-ui/automation/create-automation.tsx +++ b/surfsense_web/components/tool-ui/automation/create-automation.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useAtomValue } from "jotai"; -import { AlertCircle, Clock3, CornerDownLeftIcon, ExternalLink, Pencil } from "lucide-react"; +import { AlertCircle, AlarmClock, CornerDownLeftIcon, ExternalLink, Pencil } from "lucide-react"; import Link from "next/link"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { @@ -211,7 +211,7 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) {

- +

{phase === "rejected" @@ -404,7 +404,7 @@ function SavedCard({ result }: { result: SavedResult }) { return (

- +

Automation saved

{result.name}

diff --git a/surfsense_web/contracts/enums/toolIcons.tsx b/surfsense_web/contracts/enums/toolIcons.tsx index 496c26577..98f72796d 100644 --- a/surfsense_web/contracts/enums/toolIcons.tsx +++ b/surfsense_web/contracts/enums/toolIcons.tsx @@ -1,7 +1,7 @@ import { Brain, Calendar, - Clock3, + AlarmClock, FileEdit, FilePlus, FileText, @@ -47,7 +47,7 @@ const TOOL_ICONS: Record = { scrape_webpage: ScanLine, web_search: Globe, // Automations - create_automation: Clock3, + create_automation: AlarmClock, // Memory update_memory: Brain, // Filesystem (built-in deepagent + middleware) From 02070201fb7b2b18fa326e81b947691c6ae02d0c Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 09:29:38 +0530 Subject: [PATCH 105/212] feat(model-connections): add hint support for API Base URL field and improve dialog accessibility --- .../model-connections/connect-fields.tsx | 17 ++++++++++------- .../model-connections/default-connect-form.tsx | 16 +++++++++++++++- .../provider-connect-dialog.tsx | 15 ++++++++++++--- 3 files changed, 37 insertions(+), 11 deletions(-) diff --git a/surfsense_web/components/settings/model-connections/connect-fields.tsx b/surfsense_web/components/settings/model-connections/connect-fields.tsx index 6ef7a8408..9febb8a1e 100644 --- a/surfsense_web/components/settings/model-connections/connect-fields.tsx +++ b/surfsense_web/components/settings/model-connections/connect-fields.tsx @@ -1,4 +1,5 @@ import { Eye, EyeOff } from "lucide-react"; +import type { ReactNode } from "react"; import { useState } from "react"; import { Button } from "@/components/ui/button"; import { DialogFooter } from "@/components/ui/dialog"; @@ -9,25 +10,27 @@ import { Spinner } from "@/components/ui/spinner"; interface ApiBaseUrlFieldProps { value: string; onChange: (value: string) => void; - optional?: boolean; /** Placeholder, typically the provider's prefilled default base URL. */ placeholder?: string; + hint?: ReactNode; } /** Shared API Base URL input. The prefilled default is passed in via `value`. */ -export function ApiBaseUrlField({ value, onChange, optional, placeholder }: ApiBaseUrlFieldProps) { +export function ApiBaseUrlField({ + value, + onChange, + placeholder, + hint, +}: ApiBaseUrlFieldProps) { return (
- + onChange(event.target.value)} placeholder={placeholder || "https://api.example.com/v1"} /> -

- Local URLs are tested from the backend container, so use host.docker.internal instead of - localhost. -

+ {hint ?

{hint}

: null}
); } diff --git a/surfsense_web/components/settings/model-connections/default-connect-form.tsx b/surfsense_web/components/settings/model-connections/default-connect-form.tsx index 768c0b5da..ce638f5e6 100644 --- a/surfsense_web/components/settings/model-connections/default-connect-form.tsx +++ b/surfsense_web/components/settings/model-connections/default-connect-form.tsx @@ -2,6 +2,19 @@ import { useEffect, useState } from "react"; import { ApiBaseUrlField, ApiKeyField } from "./connect-fields"; import type { ProviderConnectFormProps } from "./provider-metadata"; +function baseUrlHint(provider: string) { + if (provider === "ollama_chat" || provider === "lm_studio") { + return "For local servers, use host.docker.internal instead of localhost."; + } + if (provider === "openai_compatible") { + return "Enter the full endpoint URL."; + } + if (provider === "openai" || provider === "anthropic" || provider === "openrouter") { + return "Override only if you route through a proxy or gateway."; + } + return undefined; +} + /** * Connect form for OpenAI-compatible / native key providers (OpenAI, Anthropic, * OpenRouter, OpenAI-Compatible, LM Studio, Ollama, …). The base URL is @@ -16,6 +29,7 @@ export function DefaultConnectForm({ const [baseUrl, setBaseUrl] = useState(defaultBaseUrl); const [apiKey, setApiKey] = useState(""); const isOllama = provider === "ollama_chat"; + const hint = baseUrlHint(provider); const canSubmit = !(baseUrlRequired && !baseUrl.trim()); useEffect(() => { @@ -27,8 +41,8 @@ export function DefaultConnectForm({ (null); const [currentDraft, setCurrentDraft] = useState({ base_url: null, api_key: null, @@ -99,12 +100,20 @@ export function ProviderConnectDialog({ return ( - + { + event.preventDefault(); + titleRef.current?.focus(); + }} + >
{providerIcon(provider, "size-5")}
- Connect {meta.name} + + Connect {meta.name} + {meta.subtitle}
From 50668775f81a81e319d6cbb47f02a326bba92403 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 12:14:17 +0530 Subject: [PATCH 106/212] feat(model-selector): enhance model selection with connection scope and free model indication --- .../components/new-chat/model-selector.tsx | 100 ++++++++++-------- 1 file changed, 58 insertions(+), 42 deletions(-) diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 0f5f50849..90226dde5 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -1,8 +1,10 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { Check, ChevronDown, Cpu, ImageOff, Search, Settings2, Zap } from "lucide-react"; -import { useMemo, useState } from "react"; +import { Check, ChevronDown, Cpu, Search, Settings2, Zap } from "lucide-react"; +import { useRouter } from "next/navigation"; +import type { UIEvent } from "react"; +import { useCallback, useMemo, useState } from "react"; import { updateModelRolesMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { globalModelConnectionsAtom, @@ -23,41 +25,30 @@ import { Input } from "@/components/ui/input"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Spinner } from "@/components/ui/spinner"; import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; -import type { - GlobalImageGenConfig, - GlobalNewLLMConfig, - GlobalVisionLLMConfig, - ImageGenerationConfig, - NewLLMConfigPublic, - VisionLLMConfig, -} from "@/contracts/types/new-llm-config.types"; import { useIsMobile } from "@/hooks/use-mobile"; import { getProviderIcon } from "@/lib/provider-icons"; import { cn } from "@/lib/utils"; +import { providerDisplay } from "../settings/model-connections/provider-metadata"; interface ModelSelectorProps { - onEditLLM: (config: NewLLMConfigPublic | GlobalNewLLMConfig, isGlobal: boolean) => void; - onAddNewLLM: (provider?: string) => void; - onEditImage?: (config: ImageGenerationConfig | GlobalImageGenConfig, isGlobal: boolean) => void; - onAddNewImage?: (provider?: string) => void; - onEditVision?: (config: VisionLLMConfig | GlobalVisionLLMConfig, isGlobal: boolean) => void; - onAddNewVision?: (provider?: string) => void; + searchSpaceId: number; className?: string; } type ChatModel = ModelRead & { connectionId: number; connectionLabel: string; + connectionScope: string; provider: string; }; function modelName(model: ModelRead) { - return model.display_name || model.model_id; + return (model.display_name || model.model_id).replace(/\s+\(free\)$/i, ""); } function connectionLabel(connection: ConnectionRead) { - if (connection.scope === "GLOBAL") return "Hosted"; - return connection.provider; + if (connection.scope === "GLOBAL") return "Global"; + return providerDisplay(connection.provider).name; } function flattenChatModels(connections: ConnectionRead[]) { @@ -68,11 +59,16 @@ function flattenChatModels(connections: ConnectionRead[]) { ...model, connectionId: connection.id, connectionLabel: connectionLabel(connection), + connectionScope: connection.scope, provider: connection.provider, })) ); } +function isFreeGlobalModel(model: ChatModel) { + return model.connectionScope === "GLOBAL" && model.billing_tier?.toLowerCase() === "free"; +} + function groupedModels(models: ChatModel[]) { return models.reduce>((groups, model) => { const key = model.connectionLabel; @@ -83,23 +79,14 @@ function groupedModels(models: ChatModel[]) { } export function ModelSelector({ - onAddNewLLM, - onEditLLM, - onEditImage, - onAddNewImage, - onEditVision, - onAddNewVision, + searchSpaceId, className, }: ModelSelectorProps) { - void onEditLLM; - void onEditImage; - void onAddNewImage; - void onEditVision; - void onAddNewVision; - + const router = useRouter(); const isMobile = useIsMobile(); const [open, setOpen] = useState(false); const [search, setSearch] = useState(""); + const [scrollPos, setScrollPos] = useState<"top" | "middle" | "bottom">("top"); const [{ data: globalConnections = [], isLoading: globalLoading }] = useAtom( globalModelConnectionsAtom ); @@ -130,11 +117,18 @@ export function ModelSelector({ function manageModelConnections() { setOpen(false); - onAddNewLLM(); + router.push(`/dashboard/${searchSpaceId}/search-space-settings/models`); } + const handleScroll = useCallback((event: UIEvent) => { + const el = event.currentTarget; + const atTop = el.scrollTop <= 2; + const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; + setScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); + }, []); + const content = ( -
+
@@ -146,7 +140,14 @@ export function ModelSelector({ />
-
+
@@ -228,6 +243,7 @@ export function ModelSelector({ size="sm" className={cn( "h-8 min-w-0 gap-2 rounded-md px-3 text-muted-foreground transition-colors", + "select-none", "hover:bg-foreground/10 hover:text-foreground", "data-[state=open]:bg-foreground/10 data-[state=open]:text-foreground", className @@ -263,7 +279,7 @@ export function ModelSelector({ return ( {trigger} - + {content} From bd4a04f2e72a9aebb7a5f614f83bbd43d97487f1 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 12:45:43 +0530 Subject: [PATCH 107/212] feat(database-migrations): add migration to remove legacy model config tables and remove stale model connection code --- ...38_add_thread_auto_model_pinning_fields.py | 2 +- .../161_remove_legacy_model_configs.py | 270 +++ .../deliverables/tools/generate_image.py | 2 +- .../app/agents/chat/runtime/llm_config.py | 132 +- .../app/automations/services/model_policy.py | 12 +- surfsense_backend/app/config/__init__.py | 50 +- .../app/config/global_llm_config.example.yaml | 80 +- surfsense_backend/app/db.py | 287 +-- .../prompts/default_system_instructions.py | 4 +- .../system_prompt_composer/composer.py | 3 +- surfsense_backend/app/routes/__init__.py | 4 - .../app/routes/image_generation_routes.py | 291 +-- .../app/routes/model_connections_routes.py | 17 +- .../app/routes/new_llm_config_routes.py | 480 ----- .../app/routes/search_spaces_routes.py | 360 +--- .../app/routes/vision_llm_routes.py | 304 ---- surfsense_backend/app/schemas/__init__.py | 45 - .../app/schemas/image_generation.py | 165 +- .../app/schemas/new_llm_config.py | 256 --- surfsense_backend/app/schemas/vision_llm.py | 116 -- .../app/services/auto_model_pin_service.py | 20 +- .../app/services/billable_calls.py | 57 +- surfsense_backend/app/services/llm_service.py | 24 +- .../app/services/model_list_service.py | 2 +- .../openrouter_integration_service.py | 93 +- .../app/services/pricing_registration.py | 6 +- .../app/services/quality_score.py | 2 +- .../app/services/vision_llm_router_service.py | 160 -- .../app/services/vision_model_list_service.py | 134 -- .../scripts/verify_chat_image_capability.py | 43 +- .../builtin/agent_task/test_dependencies.py | 40 +- .../runtime/test_executor_action_ctx.py | 18 +- .../schemas/definition/test_envelope.py | 18 +- .../test_automation_service_policy.py | 120 +- .../automations/services/test_model_policy.py | 62 +- .../routes/test_byok_supports_image_input.py | 110 -- .../routes/test_global_configs_is_premium.py | 184 -- ...t_global_new_llm_configs_supports_image.py | 106 -- .../tests/unit/routes/test_image_gen_quota.py | 62 +- .../services/test_agent_billing_resolver.py | 232 +-- .../services/test_auto_model_pin_service.py | 135 +- .../test_image_gen_api_base_defense.py | 54 +- .../test_openrouter_integration_service.py | 93 +- .../services/test_pricing_registration.py | 74 - .../tests/unit/services/test_quality_score.py | 2 +- .../test_vision_llm_api_base_defense.py | 77 - surfsense_evals/README.md | 4 +- .../parser_compare/run_artifact.json | 2 +- .../src/surfsense_evals/core/cli.py | 73 +- .../core/clients/search_space.py | 116 +- .../src/surfsense_evals/core/config.py | 21 +- .../src/surfsense_evals/core/registry.py | 4 +- .../src/surfsense_evals/core/vision_llm.py | 4 +- .../suites/medical/medxpertqa/runner.py | 2 +- .../multimodal_doc/mmlongbench/runner.py | 2 +- .../multimodal_doc/parser_compare/runner.py | 2 +- .../suites/research/crag/runner.py | 2 +- .../suites/research/frames/runner.py | 2 +- surfsense_evals/tests/core/test_clients.py | 23 +- surfsense_evals/tests/core/test_config.py | 30 +- .../tests/test_integration_smoke.py | 2 +- .../image-models/page.tsx | 6 - .../search-space-settings/roles/page.tsx | 6 - .../vision-models/page.tsx | 6 - .../image-gen-config-mutation.atoms.ts | 96 - .../image-gen-config-query.atoms.ts | 33 - .../new-llm-config-mutation.atoms.ts | 132 -- .../new-llm-config-query.atoms.ts | 98 -- .../vision-llm-config-mutation.atoms.ts | 87 - .../vision-llm-config-query.atoms.ts | 51 - .../components/new-chat/chat-header.tsx | 153 +- .../settings/agent-model-manager.tsx | 423 ----- .../settings/image-model-manager.tsx | 489 ------ .../components/settings/llm-role-manager.tsx | 443 ----- .../settings/vision-model-manager.tsx | 486 ----- .../components/shared/image-config-dialog.tsx | 456 ----- .../components/shared/llm-config-form.tsx | 527 ------ .../components/shared/model-config-dialog.tsx | 339 ---- .../shared/vision-config-dialog.tsx | 478 ----- .../contracts/enums/image-gen-providers.ts | 105 -- surfsense_web/contracts/enums/llm-models.ts | 1558 ----------------- .../contracts/enums/llm-providers.ts | 197 --- .../contracts/enums/vision-providers.ts | 168 -- .../contracts/types/new-llm-config.types.ts | 476 ----- .../lib/apis/image-gen-config-api.service.ts | 81 - .../lib/apis/new-llm-config-api.service.ts | 178 -- .../lib/apis/vision-llm-config-api.service.ts | 63 - surfsense_web/lib/query-client/cache-keys.ts | 19 - surfsense_web/messages/en.json | 8 - surfsense_web/messages/es.json | 31 +- surfsense_web/messages/hi.json | 31 +- surfsense_web/messages/pt.json | 31 +- surfsense_web/messages/zh.json | 46 +- 93 files changed, 956 insertions(+), 11442 deletions(-) create mode 100644 surfsense_backend/alembic/versions/161_remove_legacy_model_configs.py delete mode 100644 surfsense_backend/app/routes/new_llm_config_routes.py delete mode 100644 surfsense_backend/app/routes/vision_llm_routes.py delete mode 100644 surfsense_backend/app/schemas/new_llm_config.py delete mode 100644 surfsense_backend/app/schemas/vision_llm.py delete mode 100644 surfsense_backend/app/services/vision_llm_router_service.py delete mode 100644 surfsense_backend/app/services/vision_model_list_service.py delete mode 100644 surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py delete mode 100644 surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py delete mode 100644 surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py delete mode 100644 surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py delete mode 100644 surfsense_web/app/dashboard/[search_space_id]/search-space-settings/image-models/page.tsx delete mode 100644 surfsense_web/app/dashboard/[search_space_id]/search-space-settings/roles/page.tsx delete mode 100644 surfsense_web/app/dashboard/[search_space_id]/search-space-settings/vision-models/page.tsx delete mode 100644 surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts delete mode 100644 surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts delete mode 100644 surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts delete mode 100644 surfsense_web/atoms/new-llm-config/new-llm-config-query.atoms.ts delete mode 100644 surfsense_web/atoms/vision-llm-config/vision-llm-config-mutation.atoms.ts delete mode 100644 surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts delete mode 100644 surfsense_web/components/settings/agent-model-manager.tsx delete mode 100644 surfsense_web/components/settings/image-model-manager.tsx delete mode 100644 surfsense_web/components/settings/llm-role-manager.tsx delete mode 100644 surfsense_web/components/settings/vision-model-manager.tsx delete mode 100644 surfsense_web/components/shared/image-config-dialog.tsx delete mode 100644 surfsense_web/components/shared/llm-config-form.tsx delete mode 100644 surfsense_web/components/shared/model-config-dialog.tsx delete mode 100644 surfsense_web/components/shared/vision-config-dialog.tsx delete mode 100644 surfsense_web/contracts/enums/image-gen-providers.ts delete mode 100644 surfsense_web/contracts/enums/llm-models.ts delete mode 100644 surfsense_web/contracts/enums/llm-providers.ts delete mode 100644 surfsense_web/contracts/enums/vision-providers.ts delete mode 100644 surfsense_web/contracts/types/new-llm-config.types.ts delete mode 100644 surfsense_web/lib/apis/image-gen-config-api.service.ts delete mode 100644 surfsense_web/lib/apis/new-llm-config-api.service.ts delete mode 100644 surfsense_web/lib/apis/vision-llm-config-api.service.ts diff --git a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py index fba621a0c..8c74b637b 100644 --- a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py +++ b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py @@ -4,7 +4,7 @@ Revision ID: 138 Revises: 137 Create Date: 2026-04-30 -Add a single thread-level column to persist the Auto (Fastest) model pin: +Add a single thread-level column to persist the Auto model pin: - pinned_llm_config_id: concrete resolved global LLM config id used for this thread. NULL means "no pin; Auto will resolve on next turn". diff --git a/surfsense_backend/alembic/versions/161_remove_legacy_model_configs.py b/surfsense_backend/alembic/versions/161_remove_legacy_model_configs.py new file mode 100644 index 000000000..2108d763c --- /dev/null +++ b/surfsense_backend/alembic/versions/161_remove_legacy_model_configs.py @@ -0,0 +1,270 @@ +"""remove legacy model config tables + +Revision ID: 161 +Revises: 160 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +from sqlalchemy.types import TypeEngine + +from alembic import op + +revision: str = "161" +down_revision: str | None = "160" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +litellm_provider = postgresql.ENUM( + "OPENAI", + "ANTHROPIC", + "GOOGLE", + "AZURE_OPENAI", + "BEDROCK", + "VERTEX_AI", + "GROQ", + "COHERE", + "MISTRAL", + "DEEPSEEK", + "XAI", + "OPENROUTER", + "TOGETHER_AI", + "FIREWORKS_AI", + "REPLICATE", + "PERPLEXITY", + "OLLAMA", + "ALIBABA_QWEN", + "MOONSHOT", + "ZHIPU", + "ANYSCALE", + "DEEPINFRA", + "CEREBRAS", + "SAMBANOVA", + "AI21", + "CLOUDFLARE", + "DATABRICKS", + "COMETAPI", + "HUGGINGFACE", + "GITHUB_MODELS", + "MINIMAX", + "CUSTOM", + name="litellmprovider", + create_type=False, +) +image_gen_provider = postgresql.ENUM( + "OPENAI", + "AZURE_OPENAI", + "GOOGLE", + "VERTEX_AI", + "BEDROCK", + "RECRAFT", + "OPENROUTER", + "XINFERENCE", + "NSCALE", + name="imagegenprovider", + create_type=False, +) +vision_provider = postgresql.ENUM( + "OPENAI", + "ANTHROPIC", + "GOOGLE", + "AZURE_OPENAI", + "VERTEX_AI", + "BEDROCK", + "XAI", + "OPENROUTER", + "OLLAMA", + "GROQ", + "TOGETHER_AI", + "FIREWORKS_AI", + "DEEPSEEK", + "MISTRAL", + "CUSTOM", + name="visionprovider", + create_type=False, +) + + +def _table_exists(table_name: str) -> bool: + return table_name in sa.inspect(op.get_bind()).get_table_names() + + +def _column_exists(table_name: str, column_name: str) -> bool: + if not _table_exists(table_name): + return False + return column_name in { + column["name"] for column in sa.inspect(op.get_bind()).get_columns(table_name) + } + + +def _drop_column_if_exists(table_name: str, column_name: str) -> None: + if _column_exists(table_name, column_name): + op.drop_column(table_name, column_name) + + +def _rename_column_if_exists( + table_name: str, + old_column_name: str, + new_column_name: str, + *, + existing_type: TypeEngine, + existing_nullable: bool = True, +) -> None: + if _column_exists(table_name, old_column_name) and not _column_exists( + table_name, new_column_name + ): + op.alter_column( + table_name, + old_column_name, + new_column_name=new_column_name, + existing_type=existing_type, + existing_nullable=existing_nullable, + ) + + +def upgrade() -> None: + for table_name in ( + "new_llm_configs", + "vision_llm_configs", + "image_generation_configs", + ): + if _table_exists(table_name): + op.drop_table(table_name) + + _drop_column_if_exists("searchspaces", "agent_llm_id") + _drop_column_if_exists("searchspaces", "image_generation_config_id") + _drop_column_if_exists("searchspaces", "vision_llm_config_id") + + _rename_column_if_exists( + "image_generations", + "image_generation_config_id", + "image_gen_model_id", + existing_type=sa.Integer(), + ) + + op.execute("DROP TYPE IF EXISTS litellmprovider") + op.execute("DROP TYPE IF EXISTS imagegenprovider") + op.execute("DROP TYPE IF EXISTS visionprovider") + + +def downgrade() -> None: + bind = op.get_bind() + litellm_provider.create(bind, checkfirst=True) + image_gen_provider.create(bind, checkfirst=True) + vision_provider.create(bind, checkfirst=True) + + _rename_column_if_exists( + "image_generations", + "image_gen_model_id", + "image_generation_config_id", + existing_type=sa.Integer(), + ) + + if _table_exists("searchspaces"): + if not _column_exists("searchspaces", "agent_llm_id"): + op.add_column( + "searchspaces", + sa.Column("agent_llm_id", sa.Integer(), nullable=True), + ) + if not _column_exists("searchspaces", "image_generation_config_id"): + op.add_column( + "searchspaces", + sa.Column("image_generation_config_id", sa.Integer(), nullable=True), + ) + if not _column_exists("searchspaces", "vision_llm_config_id"): + op.add_column( + "searchspaces", + sa.Column("vision_llm_config_id", sa.Integer(), nullable=True), + ) + + if not _table_exists("image_generation_configs"): + op.create_table( + "image_generation_configs", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("name", sa.String(length=100), nullable=False), + sa.Column("description", sa.String(length=500), nullable=True), + sa.Column("provider", image_gen_provider, nullable=False), + sa.Column("custom_provider", sa.String(length=100), nullable=True), + sa.Column("model_name", sa.String(length=100), nullable=False), + sa.Column("api_key", sa.String(), nullable=False), + sa.Column("api_base", sa.String(length=500), nullable=True), + sa.Column("api_version", sa.String(length=50), nullable=True), + sa.Column("litellm_params", sa.JSON(), nullable=True), + sa.Column("search_space_id", sa.Integer(), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.ForeignKeyConstraint( + ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_image_generation_configs_name"), + "image_generation_configs", + ["name"], + unique=False, + ) + + if not _table_exists("vision_llm_configs"): + op.create_table( + "vision_llm_configs", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("name", sa.String(length=100), nullable=False), + sa.Column("description", sa.String(length=500), nullable=True), + sa.Column("provider", vision_provider, nullable=False), + sa.Column("custom_provider", sa.String(length=100), nullable=True), + sa.Column("model_name", sa.String(length=100), nullable=False), + sa.Column("api_key", sa.String(), nullable=False), + sa.Column("api_base", sa.String(length=500), nullable=True), + sa.Column("api_version", sa.String(length=50), nullable=True), + sa.Column("litellm_params", sa.JSON(), nullable=True), + sa.Column("search_space_id", sa.Integer(), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.ForeignKeyConstraint( + ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_vision_llm_configs_name"), + "vision_llm_configs", + ["name"], + unique=False, + ) + + if not _table_exists("new_llm_configs"): + op.create_table( + "new_llm_configs", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("name", sa.String(length=100), nullable=False), + sa.Column("description", sa.String(length=500), nullable=True), + sa.Column("provider", litellm_provider, nullable=False), + sa.Column("custom_provider", sa.String(length=100), nullable=True), + sa.Column("model_name", sa.String(length=100), nullable=False), + sa.Column("api_key", sa.String(), nullable=False), + sa.Column("api_base", sa.String(length=500), nullable=True), + sa.Column("litellm_params", sa.JSON(), nullable=True), + sa.Column("system_instructions", sa.Text(), nullable=False), + sa.Column("use_default_system_instructions", sa.Boolean(), nullable=False), + sa.Column("citations_enabled", sa.Boolean(), nullable=False), + sa.Column("search_space_id", sa.Integer(), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.ForeignKeyConstraint( + ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_new_llm_configs_name"), + "new_llm_configs", + ["name"], + unique=False, + ) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py index 505831faa..d847e021a 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py @@ -215,7 +215,7 @@ def create_generate_image_tool( prompt=prompt, model=getattr(response, "_hidden_params", {}).get("model"), n=n, - image_generation_config_id=config_id, + image_gen_model_id=config_id, response_data=response_dict, search_space_id=search_space_id, access_token=access_token, diff --git a/surfsense_backend/app/agents/chat/runtime/llm_config.py b/surfsense_backend/app/agents/chat/runtime/llm_config.py index efc188df8..e00d16ee8 100644 --- a/surfsense_backend/app/agents/chat/runtime/llm_config.py +++ b/surfsense_backend/app/agents/chat/runtime/llm_config.py @@ -24,8 +24,6 @@ from langchain_core.messages import AIMessage, BaseMessage from langchain_core.outputs import ChatGenerationChunk, ChatResult from langchain_litellm import ChatLiteLLM from litellm import get_model_info -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession from app.agents.chat.runtime.prompt_caching import ( apply_litellm_prompt_caching, @@ -34,7 +32,6 @@ from app.services.llm_router_service import ( AUTO_MODE_ID, ChatLiteLLMRouter, _sanitize_content, - is_auto_mode, ) @@ -130,7 +127,7 @@ class AgentConfig: """ Complete configuration for the SurfSense agent. - This combines LLM settings with prompt configuration from NewLLMConfig. + This combines resolved model settings with prompt configuration. Supports Auto mode metadata (ID 0). Runtime callers must resolve Auto to a concrete global or BYOK model before constructing ChatLiteLLM. """ @@ -180,7 +177,7 @@ class AgentConfig: use_default_system_instructions=True, citations_enabled=True, config_id=AUTO_MODE_ID, - config_name="Auto (Fastest)", + config_name="Auto", is_auto_mode=True, billing_tier="free", is_premium=False, @@ -191,57 +188,12 @@ class AgentConfig: supports_image_input=True, ) - @classmethod - def from_new_llm_config(cls, config) -> "AgentConfig": - """Build an AgentConfig from a NewLLMConfig database model.""" - # Lazy import: keeps provider_capabilities (and litellm) out of init order. - from app.services.provider_capabilities import derive_supports_image_input - - provider_value = ( - config.provider.value - if hasattr(config.provider, "value") - else str(config.provider) - ) - litellm_params = config.litellm_params or {} - base_model = ( - litellm_params.get("base_model") - if isinstance(litellm_params, dict) - else None - ) - - return cls( - provider=provider_value, - model_name=config.model_name, - api_key=config.api_key, - api_base=config.api_base, - custom_provider=config.custom_provider, - litellm_params=config.litellm_params, - system_instructions=config.system_instructions, - use_default_system_instructions=config.use_default_system_instructions, - citations_enabled=config.citations_enabled, - config_id=config.id, - config_name=config.name, - is_auto_mode=False, - billing_tier="free", - is_premium=False, - anonymous_enabled=False, - quota_reserve_tokens=None, - # BYOK rows have no curated flag; ask LiteLLM (default-allow on - # unknown). The streaming safety net still blocks explicit text-only. - supports_image_input=derive_supports_image_input( - provider=provider_value.lower(), - model_name=config.model_name, - base_model=base_model, - custom_provider=config.custom_provider, - ), - ) - @classmethod def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig": """Build an AgentConfig from a YAML configuration dictionary. - Supports the same prompt fields as NewLLMConfig (system_instructions, - use_default_system_instructions, citations_enabled). + Supports prompt fields such as system_instructions, + use_default_system_instructions, and citations_enabled. """ # Lazy import: keeps provider_capabilities (and litellm) out of init order. from app.services.provider_capabilities import derive_supports_image_input @@ -334,82 +286,6 @@ def load_global_llm_config_by_id(llm_config_id: int) -> dict | None: return load_llm_config_from_yaml(llm_config_id) -async def load_new_llm_config_from_db( - session: AsyncSession, - config_id: int, -) -> "AgentConfig | None": - """Load a NewLLMConfig from the database by ID.""" - from app.db import NewLLMConfig - - try: - result = await session.execute( - select(NewLLMConfig).filter(NewLLMConfig.id == config_id) - ) - config = result.scalars().first() - - if not config: - print(f"Error: NewLLMConfig with id {config_id} not found") - return None - - return AgentConfig.from_new_llm_config(config) - except Exception as e: - print(f"Error loading NewLLMConfig from database: {e}") - return None - - -async def load_agent_llm_config_for_search_space( - session: AsyncSession, - search_space_id: int, -) -> "AgentConfig | None": - """Load the chat model config for a search space via its agent_llm_id. - - Positive id -> DB; negative -> YAML; None -> first global config (-1). - """ - from app.db import SearchSpace - - try: - result = await session.execute( - select(SearchSpace).filter(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - - if not search_space: - print(f"Error: SearchSpace with id {search_space_id} not found") - return None - - config_id = ( - search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 - ) - return await load_agent_config(session, config_id, search_space_id) - except Exception as e: - print(f"Error loading chat model config for search space {search_space_id}: {e}") - return None - - -async def load_agent_config( - session: AsyncSession, - config_id: int, - search_space_id: int | None = None, -) -> "AgentConfig | None": - """Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB.""" - if is_auto_mode(config_id): - return AgentConfig.from_auto_mode() - - if config_id < 0: - # In-memory covers static YAML + dynamic OpenRouter configs. - from app.config import config as app_config - - for cfg in app_config.GLOBAL_LLM_CONFIGS: - if cfg.get("id") == config_id: - return AgentConfig.from_yaml_config(cfg) - yaml_config = load_llm_config_from_yaml(config_id) - if yaml_config: - return AgentConfig.from_yaml_config(yaml_config) - return None - else: - return await load_new_llm_config_from_db(session, config_id) - - def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: """Create a ChatLiteLLM instance from a global LLM config dictionary.""" if llm_config.get("custom_provider"): diff --git a/surfsense_backend/app/automations/services/model_policy.py b/surfsense_backend/app/automations/services/model_policy.py index b160fc78d..e18264246 100644 --- a/surfsense_backend/app/automations/services/model_policy.py +++ b/surfsense_backend/app/automations/services/model_policy.py @@ -2,11 +2,11 @@ Automations run unattended, so every run must be **billable**: it may only use either a premium global model (``billing_tier == "premium"``) or a user-provided -BYOK model (a positive config id pointing at a per-user/per-space DB row). Free +BYOK model (a positive model id pointing at a per-user/per-space DB row). Free global models and Auto mode are blocked, because Auto can dispatch to a free deployment and free models aren't metered in premium credits. -Config id conventions (shared across chat / image / vision): +Model id conventions (shared across chat / image / vision): - ``id == 0`` → Auto mode (``AUTO_MODE_ID`` / ``IMAGE_GEN_AUTO_MODE_ID`` / ``VISION_AUTO_MODE_ID``). Blocked. - ``id < 0`` → global YAML/OpenRouter config. Allowed only if premium. @@ -82,7 +82,7 @@ def get_model_eligibility( The ID-based core shared by both the search-space path (creation/eligibility) and the captured-snapshot path (runtime backstop). Each violation is - ``{"kind", "config_id", "reason"}``. + ``{"kind", "model_id", "reason"}``. """ checks: list[tuple[ModelKind, int | None]] = [ ("chat", chat_model_id), @@ -91,10 +91,10 @@ def get_model_eligibility( ] violations: list[dict] = [] - for kind, config_id in checks: - allowed, reason = _classify(kind, config_id) + for kind, model_id in checks: + allowed, reason = _classify(kind, model_id) if not allowed: - violations.append({"kind": kind, "model_id": config_id, "reason": reason}) + violations.append({"kind": kind, "model_id": model_id, "reason": reason}) return {"allowed": not violations, "violations": violations} diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index d690c1d7d..8c9662aa8 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -119,7 +119,7 @@ def load_global_llm_configs(): else: seen_slugs[slug] = cfg.get("id", 0) - # Stamp Auto (Fastest) ranking metadata. YAML configs are always + # Stamp Auto ranking metadata. YAML configs are always # Tier A — operator-curated, locked first when premium-eligible. # The OpenRouter refresh tick later re-stamps health for any cfg # whose provider == "openrouter" via _enrich_health. @@ -210,42 +210,6 @@ def load_global_image_gen_configs(): return [] -def load_global_vision_llm_configs(): - data = _global_config_data() - if not data: - return [] - - try: - configs = copy.deepcopy(data.get("global_vision_llm_configs", []) or []) - for cfg in configs: - if isinstance(cfg, dict): - cfg.setdefault("billing_tier", "free") - return configs - except Exception as e: - print(f"Warning: Failed to load global vision LLM configs: {e}") - return [] - - -def load_vision_llm_router_settings(): - default_settings = { - "routing_strategy": "usage-based-routing", - "num_retries": 3, - "allowed_fails": 3, - "cooldown_time": 60, - } - - data = _global_config_data() - if not data: - return default_settings - - try: - settings = data.get("vision_llm_router_settings", {}) - return {**default_settings, **settings} - except Exception as e: - print(f"Warning: Failed to load vision LLM router settings: {e}") - return default_settings - - def load_image_gen_router_settings(): """ Load router settings for image generation Auto mode from YAML file. @@ -482,12 +446,6 @@ def initialize_image_gen_router(): print(f"Warning: Failed to initialize Image Generation Router: {e}") -def initialize_vision_llm_router(): - # Retired: vision Auto now uses shared capability-filtered model selection - # over GLOBAL/BYOK chat models with supports_image_input=true. - return - - class Config: # Check if ffmpeg is installed if not is_ffmpeg_installed(): @@ -869,12 +827,6 @@ class Config: # Router settings for Image Generation Auto mode IMAGE_GEN_ROUTER_SETTINGS = load_image_gen_router_settings() - # Global Vision LLM Configurations (optional) - GLOBAL_VISION_LLM_CONFIGS = load_global_vision_llm_configs() - - # Router settings for Vision LLM Auto mode - VISION_LLM_ROUTER_SETTINGS = load_vision_llm_router_settings() - # Virtual GLOBAL connection/model catalog. This is server-only metadata # derived from global_llm_config.yaml; GLOBAL keys are not stored in DB. from app.services.global_model_catalog import ( diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index 06676511f..c5b65fee0 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -433,87 +433,11 @@ global_image_generation_configs: # rpm: 30 # litellm_params: {} -# ============================================================================= -# Vision LLM Configuration -# ============================================================================= -# These configurations power the vision autocomplete feature (screenshot analysis). -# Only vision-capable models should be used here (e.g. GPT-4o, Gemini Pro, Claude 3). -# Supported providers: OpenAI, Anthropic, Google, Azure OpenAI, Vertex AI, Bedrock, -# xAI, OpenRouter, Ollama, Groq, Together AI, Fireworks AI, DeepSeek, Mistral, Custom -# -# Auto mode (ID 0) uses LiteLLM Router for load balancing across all vision configs. - -# Router Settings for Vision LLM Auto Mode -vision_llm_router_settings: - routing_strategy: "usage-based-routing" - num_retries: 3 - allowed_fails: 3 - cooldown_time: 60 - -global_vision_llm_configs: - # Example: OpenAI GPT-4o (recommended for vision) - - id: -1001 - name: "Global GPT-4o Vision" - description: "OpenAI's GPT-4o with strong vision capabilities" - litellm_provider: "openai" - model_name: "gpt-4o" - api_key: "sk-your-openai-api-key-here" - api_base: "https://api.openai.com/v1" - rpm: 500 - tpm: 100000 - litellm_params: - temperature: 0.3 - max_tokens: 1000 - - # Example: Google Gemini 2.0 Flash - - id: -1002 - name: "Global Gemini 2.0 Flash" - description: "Google's fast vision model with large context" - litellm_provider: "gemini" - model_name: "gemini-2.0-flash" - api_key: "your-google-ai-api-key-here" - api_base: "https://generativelanguage.googleapis.com/v1beta" - rpm: 1000 - tpm: 200000 - litellm_params: - temperature: 0.3 - max_tokens: 1000 - - # Example: Anthropic Claude 3.5 Sonnet - - id: -1003 - name: "Global Claude 3.5 Sonnet Vision" - description: "Anthropic's Claude 3.5 Sonnet with vision support" - litellm_provider: "anthropic" - model_name: "claude-3-5-sonnet-20241022" - api_key: "sk-ant-your-anthropic-api-key-here" - api_base: "https://api.anthropic.com/v1" - rpm: 1000 - tpm: 100000 - litellm_params: - temperature: 0.3 - max_tokens: 1000 - - # Example: Azure OpenAI GPT-4o - # - id: -1004 - # name: "Global Azure GPT-4o Vision" - # description: "Azure-hosted GPT-4o for vision analysis" - # litellm_provider: "azure" - # model_name: "azure/gpt-4o-deployment" - # api_key: "your-azure-api-key-here" - # api_base: "https://your-resource.openai.azure.com" - # api_version: "2024-02-15-preview" - # rpm: 500 - # tpm: 100000 - # litellm_params: - # temperature: 0.3 - # max_tokens: 1000 - # base_model: "gpt-4o" - # Notes: # - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing # - Use negative IDs to distinguish global models from BYOK/local DB models -# - IDs must be unique across chat, vision, and image generation configs -# - Suggested static ranges: chat -1..-999, vision -1001..-1999, image -2001..-2999 +# - IDs must be unique across chat and image generation configs +# - Suggested static ranges: chat -1..-999, image -2001..-2999 # - The 'api_key' field will not be exposed to users via API # - system_instructions: Custom prompt or empty string to use defaults # - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 728031fa0..38d0ffe33 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -198,81 +198,6 @@ class DocumentStatus: return None -class LiteLLMProvider(StrEnum): - """ - Enum for LLM providers supported by LiteLLM. - """ - - OPENAI = "OPENAI" - ANTHROPIC = "ANTHROPIC" - GOOGLE = "GOOGLE" - AZURE_OPENAI = "AZURE_OPENAI" - BEDROCK = "BEDROCK" - VERTEX_AI = "VERTEX_AI" - GROQ = "GROQ" - COHERE = "COHERE" - MISTRAL = "MISTRAL" - DEEPSEEK = "DEEPSEEK" - XAI = "XAI" - OPENROUTER = "OPENROUTER" - TOGETHER_AI = "TOGETHER_AI" - FIREWORKS_AI = "FIREWORKS_AI" - REPLICATE = "REPLICATE" - PERPLEXITY = "PERPLEXITY" - OLLAMA = "OLLAMA" - ALIBABA_QWEN = "ALIBABA_QWEN" - MOONSHOT = "MOONSHOT" - ZHIPU = "ZHIPU" - ANYSCALE = "ANYSCALE" - DEEPINFRA = "DEEPINFRA" - CEREBRAS = "CEREBRAS" - SAMBANOVA = "SAMBANOVA" - AI21 = "AI21" - CLOUDFLARE = "CLOUDFLARE" - DATABRICKS = "DATABRICKS" - COMETAPI = "COMETAPI" - HUGGINGFACE = "HUGGINGFACE" - GITHUB_MODELS = "GITHUB_MODELS" - MINIMAX = "MINIMAX" - CUSTOM = "CUSTOM" - - -class ImageGenProvider(StrEnum): - """ - Enum for image generation providers supported by LiteLLM. - This is a subset of LLM providers — only those that support image generation. - See: https://docs.litellm.ai/docs/image_generation#supported-providers - """ - - OPENAI = "OPENAI" - AZURE_OPENAI = "AZURE_OPENAI" - GOOGLE = "GOOGLE" # Google AI Studio - VERTEX_AI = "VERTEX_AI" - BEDROCK = "BEDROCK" # AWS Bedrock - RECRAFT = "RECRAFT" - OPENROUTER = "OPENROUTER" - XINFERENCE = "XINFERENCE" - NSCALE = "NSCALE" - - -class VisionProvider(StrEnum): - OPENAI = "OPENAI" - ANTHROPIC = "ANTHROPIC" - GOOGLE = "GOOGLE" - AZURE_OPENAI = "AZURE_OPENAI" - VERTEX_AI = "VERTEX_AI" - BEDROCK = "BEDROCK" - XAI = "XAI" - OPENROUTER = "OPENROUTER" - OLLAMA = "OLLAMA" - GROQ = "GROQ" - TOGETHER_AI = "TOGETHER_AI" - FIREWORKS_AI = "FIREWORKS_AI" - DEEPSEEK = "DEEPSEEK" - MISTRAL = "MISTRAL" - CUSTOM = "CUSTOM" - - class ConnectionScope(StrEnum): GLOBAL = "GLOBAL" SEARCH_SPACE = "SEARCH_SPACE" @@ -710,11 +635,11 @@ class NewChatThread(BaseModel, TimestampMixin): default=False, server_default="false", ) - # Auto (Fastest) model pin for this thread: concrete resolved global LLM + # Auto model pin for this thread: concrete resolved global LLM # config id. NULL means no pin; Auto will resolve on the next turn. # Single-writer invariant: only app.services.auto_model_pin_service sets # or clears this column (plus bulk clears when a search space's - # agent_llm_id changes). Unindexed: all reads are by primary key. + # chat_model_id changes). Unindexed: all reads are by primary key. pinned_llm_config_id = Column(Integer, nullable=True) # Surface metadata for first-party SurfSense and external chat threads. @@ -1686,75 +1611,6 @@ class Model(BaseModel, TimestampMixin): ) -class ImageGenerationConfig(BaseModel, TimestampMixin): - """ - Dedicated configuration table for image generation models. - - Separate from NewLLMConfig because image generation models don't need - system_instructions, citations_enabled, or use_default_system_instructions. - They only need provider credentials and model parameters. - """ - - __tablename__ = "image_generation_configs" - - name = Column(String(100), nullable=False, index=True) - description = Column(String(500), nullable=True) - - # Provider & model (uses ImageGenProvider, NOT LiteLLMProvider) - provider = Column(SQLAlchemyEnum(ImageGenProvider), nullable=False) - custom_provider = Column(String(100), nullable=True) - model_name = Column(String(100), nullable=False) - - # Credentials - api_key = Column(String, nullable=False) - api_base = Column(String(500), nullable=True) - api_version = Column(String(50), nullable=True) # Azure-specific - - # Additional litellm parameters - litellm_params = Column(JSON, nullable=True, default={}) - - # Relationships - search_space_id = Column( - Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False - ) - search_space = relationship( - "SearchSpace", back_populates="image_generation_configs" - ) - - # User who created this config - user_id = Column( - UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False - ) - user = relationship("User", back_populates="image_generation_configs") - - -class VisionLLMConfig(BaseModel, TimestampMixin): - __tablename__ = "vision_llm_configs" - - name = Column(String(100), nullable=False, index=True) - description = Column(String(500), nullable=True) - - provider = Column(SQLAlchemyEnum(VisionProvider), nullable=False) - custom_provider = Column(String(100), nullable=True) - model_name = Column(String(100), nullable=False) - - api_key = Column(String, nullable=False) - api_base = Column(String(500), nullable=True) - api_version = Column(String(50), nullable=True) - - litellm_params = Column(JSON, nullable=True, default={}) - - search_space_id = Column( - Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False - ) - search_space = relationship("SearchSpace", back_populates="vision_llm_configs") - - user_id = Column( - UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False - ) - user = relationship("User", back_populates="vision_llm_configs") - - class ImageGeneration(BaseModel, TimestampMixin): """ Stores image generation requests and results using litellm.aimage_generation(). @@ -1786,10 +1642,9 @@ class ImageGeneration(BaseModel, TimestampMixin): style = Column(String(50), nullable=True) # Model-specific style parameter response_format = Column(String(50), nullable=True) # "url" or "b64_json" - # Image generation config reference - # 0 = Auto mode (router), negative IDs = global configs from YAML, - # positive IDs = ImageGenerationConfig records in DB - image_generation_config_id = Column(Integer, nullable=True) + # Image generation model provenance. + # 0 = Auto mode, negative IDs = GLOBAL models, positive IDs = Model records. + image_gen_model_id = Column(Integer, nullable=True) # Response data (full litellm response as JSONB) — present on success response_data = Column(JSONB, nullable=True) @@ -1831,23 +1686,7 @@ class SearchSpace(BaseModel, TimestampMixin): shared_memory_md = Column(Text, nullable=True, server_default="") - # Search space-level LLM preferences (shared by all members) - # Note: ID values: - # - 0: Auto mode (uses LiteLLM Router for load balancing) - default for new search spaces - # - Negative IDs: Global configs from YAML - # - Positive IDs: Custom configs from DB (NewLLMConfig table) - agent_llm_id = Column( - Integer, nullable=True, default=0 - ) # For chat operations, defaults to Auto mode - image_generation_config_id = Column( - Integer, nullable=True, default=0 - ) # For image generation, defaults to Auto mode - vision_llm_config_id = Column( - Integer, nullable=True, default=0 - ) # For vision/screenshot analysis, defaults to Auto mode - - # New connection/model role bindings. These supersede the legacy config - # columns above without removing them in this PR. + # Connection/model role bindings. # Note: ID values preserve the existing convention: # - 0: Auto mode # - Negative IDs: Global virtual models from global_llm_config.yaml @@ -1931,24 +1770,6 @@ class SearchSpace(BaseModel, TimestampMixin): order_by="SearchSourceConnector.id", cascade="all, delete-orphan", ) - new_llm_configs = relationship( - "NewLLMConfig", - back_populates="search_space", - order_by="NewLLMConfig.id", - cascade="all, delete-orphan", - ) - image_generation_configs = relationship( - "ImageGenerationConfig", - back_populates="search_space", - order_by="ImageGenerationConfig.id", - cascade="all, delete-orphan", - ) - vision_llm_configs = relationship( - "VisionLLMConfig", - back_populates="search_space", - order_by="VisionLLMConfig.id", - cascade="all, delete-orphan", - ) connections = relationship( "Connection", back_populates="search_space", @@ -2057,64 +1878,6 @@ class SearchSourceConnector(BaseModel, TimestampMixin): documents = relationship("Document", back_populates="connector") -class NewLLMConfig(BaseModel, TimestampMixin): - """ - New LLM configuration table that combines model settings with prompt configuration. - - This table provides: - - LLM model configuration (provider, model_name, api_key, etc.) - - Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS) - - Citation toggle (enable/disable citation instructions) - - Note: Tools instructions are built by get_tools_instructions(thread_visibility) (personal vs shared memory). - """ - - __tablename__ = "new_llm_configs" - - name = Column(String(100), nullable=False, index=True) - description = Column(String(500), nullable=True) - - # === LLM Model Configuration (from original LLMConfig, excluding 'language') === - # Provider from the enum - provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False) - # Custom provider name when provider is CUSTOM - custom_provider = Column(String(100), nullable=True) - # Just the model name without provider prefix - model_name = Column(String(100), nullable=False) - # API Key should be encrypted before storing - api_key = Column(String, nullable=False) - api_base = Column(String(500), nullable=True) - # For any other parameters that litellm supports - litellm_params = Column(JSON, nullable=True, default={}) - - # === Prompt Configuration === - # Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS) - # Users can customize this from the UI - system_instructions = Column( - Text, - nullable=False, - default="", # Empty string means use default SURFSENSE_SYSTEM_INSTRUCTIONS - ) - # Whether to use the default system instructions when system_instructions is empty - use_default_system_instructions = Column(Boolean, nullable=False, default=True) - - # Citation toggle - when enabled, SURFSENSE_CITATION_INSTRUCTIONS is injected - # When disabled, an anti-citation prompt is injected instead - citations_enabled = Column(Boolean, nullable=False, default=True) - - # === Relationships === - search_space_id = Column( - Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False - ) - search_space = relationship("SearchSpace", back_populates="new_llm_configs") - - # User who created this config - user_id = Column( - UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False - ) - user = relationship("User", back_populates="new_llm_configs") - - class Log(BaseModel, TimestampMixin): __tablename__ = "logs" @@ -2481,25 +2244,6 @@ if config.AUTH_TYPE == "GOOGLE": passive_deletes=True, ) - # LLM configs created by this user - new_llm_configs = relationship( - "NewLLMConfig", - back_populates="user", - passive_deletes=True, - ) - - # Image generation configs created by this user - image_generation_configs = relationship( - "ImageGenerationConfig", - back_populates="user", - passive_deletes=True, - ) - - vision_llm_configs = relationship( - "VisionLLMConfig", - back_populates="user", - passive_deletes=True, - ) connections = relationship( "Connection", back_populates="user", @@ -2632,25 +2376,6 @@ else: passive_deletes=True, ) - # LLM configs created by this user - new_llm_configs = relationship( - "NewLLMConfig", - back_populates="user", - passive_deletes=True, - ) - - # Image generation configs created by this user - image_generation_configs = relationship( - "ImageGenerationConfig", - back_populates="user", - passive_deletes=True, - ) - - vision_llm_configs = relationship( - "VisionLLMConfig", - back_populates="user", - passive_deletes=True, - ) connections = relationship( "Connection", back_populates="user", diff --git a/surfsense_backend/app/prompts/default_system_instructions.py b/surfsense_backend/app/prompts/default_system_instructions.py index fd0a8e186..b968fc1f0 100644 --- a/surfsense_backend/app/prompts/default_system_instructions.py +++ b/surfsense_backend/app/prompts/default_system_instructions.py @@ -82,7 +82,7 @@ def build_configurable_system_prompt( *, model_name: str | None = None, ) -> str: - """Build a configurable SurfSense system prompt (NewLLMConfig path). + """Build a configurable SurfSense system prompt. See :func:`app.prompts.system_prompt_composer.composer.compose_system_prompt` for full parameter docs. @@ -104,7 +104,7 @@ def build_configurable_system_prompt( def get_default_system_instructions() -> str: """Return the default ```` block (no tools / citations). - Useful for populating the UI when seeding ``NewLLMConfig.system_instructions``. + Useful for populating the UI when editing custom system instructions. The output reflects the current fragment tree, not a baked-in constant. """ resolved_today = datetime.now(UTC).date().isoformat() diff --git a/surfsense_backend/app/prompts/system_prompt_composer/composer.py b/surfsense_backend/app/prompts/system_prompt_composer/composer.py index 3849af313..c639d4aa0 100644 --- a/surfsense_backend/app/prompts/system_prompt_composer/composer.py +++ b/surfsense_backend/app/prompts/system_prompt_composer/composer.py @@ -348,8 +348,7 @@ def compose_system_prompt( mcp_connector_tools: ``{server_name: [tool_names...]}`` to inject an explicit MCP routing block. custom_system_instructions: Free-form instructions that override - the default ```` block (legacy support - for ``NewLLMConfig.system_instructions``). + the default ```` block. use_default_system_instructions: When ``custom_system_instructions`` is empty/None, fall back to defaults (legacy semantics). citations_enabled: Include ``citations_on.md`` (true) or diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index f9f6b3d28..2b997cef5 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -47,7 +47,6 @@ from .model_connections_routes import router as model_connections_router from .memory_routes import router as memory_router from .model_list_routes import router as model_list_router from .new_chat_routes import router as new_chat_router -from .new_llm_config_routes import router as new_llm_config_router from .notes_routes import router as notes_router from .notion_add_connector_route import router as notion_add_connector_router from .obsidian_plugin_routes import router as obsidian_plugin_router @@ -64,7 +63,6 @@ from .stripe_routes import router as stripe_router from .team_memory_routes import router as team_memory_router from .teams_add_connector_route import router as teams_add_connector_router from .video_presentations_routes import router as video_presentations_router -from .vision_llm_routes import router as vision_llm_router from .youtube_routes import router as youtube_router router = APIRouter() @@ -99,7 +97,6 @@ router.include_router( ) # Video presentation status and streaming router.include_router(reports_router) # Report CRUD and multi-format export router.include_router(image_generation_router) # Image generation via litellm -router.include_router(vision_llm_router) # Vision LLM configs for screenshot analysis router.include_router(search_source_connectors_router) router.include_router(google_calendar_add_connector_router) router.include_router(google_gmail_add_connector_router) @@ -117,7 +114,6 @@ router.include_router(jira_add_connector_router) router.include_router(confluence_add_connector_router) router.include_router(clickup_add_connector_router) router.include_router(dropbox_add_connector_router) -router.include_router(new_llm_config_router) # LLM configs with prompt configuration router.include_router(model_connections_router) # Connection-centric model catalog router.include_router(model_list_router) # Dynamic model catalogue from OpenRouter router.include_router(logs_router) diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 29a2b58bc..7e95d4dba 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -1,7 +1,5 @@ """ Image Generation routes: -- CRUD for ImageGenerationConfig (user-created image model configs) -- Global image gen configs endpoint (from YAML) - Image generation execution (calls litellm.aimage_generation()) - CRUD for ImageGeneration records (results) - Image serving endpoint (serves b64_json images from DB, protected by signed tokens) @@ -21,7 +19,6 @@ from sqlalchemy.orm import selectinload from app.config import config from app.db import ( ImageGeneration, - ImageGenerationConfig, Model, Permission, SearchSpace, @@ -30,14 +27,14 @@ from app.db import ( get_async_session, ) from app.schemas import ( - GlobalImageGenConfigRead, - ImageGenerationConfigCreate, - ImageGenerationConfigRead, - ImageGenerationConfigUpdate, ImageGenerationCreate, ImageGenerationListRead, ImageGenerationRead, ) +from app.services.auto_model_pin_service import ( + auto_model_candidates, + choose_auto_model_candidate, +) from app.services.billable_calls import ( DEFAULT_IMAGE_RESERVE_MICROS, QuotaInsufficientError, @@ -47,12 +44,8 @@ from app.services.image_gen_router_service import ( IMAGE_GEN_AUTO_MODE_ID, is_image_gen_auto_mode, ) -from app.services.auto_model_pin_service import ( - auto_model_candidates, - choose_auto_model_candidate, -) -from app.services.model_resolver import to_litellm from app.services.model_capabilities import has_capability +from app.services.model_resolver import to_litellm from app.users import current_active_user from app.utils.rbac import check_permission from app.utils.signed_image_urls import verify_image_token @@ -131,14 +124,14 @@ async def _execute_image_generation( Call litellm.aimage_generation() with the appropriate config. Resolution order: - 1. Explicit image_generation_config_id on the request - 2. Search space's image_generation_config_id preference + 1. Explicit image_gen_model_id on the request + 2. Search space's image_gen_model_id preference 3. Falls back to Auto mode if available """ - config_id = image_gen.image_generation_config_id + config_id = image_gen.image_gen_model_id if config_id is None: config_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID - image_gen.image_generation_config_id = config_id + image_gen.image_gen_model_id = config_id # Build kwargs gen_kwargs = {} @@ -163,7 +156,7 @@ async def _execute_image_generation( if not candidates: raise ValueError("No image-generation models are available for Auto mode") config_id = int(choose_auto_model_candidate(candidates, search_space.id)["id"]) - image_gen.image_generation_config_id = config_id + image_gen.image_gen_model_id = config_id if config_id < 0: global_model = _get_global_model(config_id) @@ -228,266 +221,6 @@ async def _execute_image_generation( image_gen.model = hidden["model"] -# ============================================================================= -# Global Image Generation Configs (from YAML) -# ============================================================================= - - -@router.get( - "/global-image-generation-configs", - response_model=list[GlobalImageGenConfigRead], -) -async def get_global_image_gen_configs( - user: User = Depends(current_active_user), -): - """Get all global image generation configs. API keys are hidden.""" - try: - global_configs = config.GLOBAL_IMAGE_GEN_CONFIGS - safe_configs = [] - - if global_configs and len(global_configs) > 0: - safe_configs.append( - { - "id": 0, - "name": "Auto (Fastest)", - "description": "Automatically routes across available image generation providers.", - "provider": "AUTO", - "custom_provider": None, - "model_name": "auto", - "api_base": None, - "api_version": None, - "litellm_params": {}, - "is_global": True, - "is_auto_mode": True, - # Auto mode currently treated as free until per-deployment - # billing-tier surfacing lands (see _resolve_billing_for_image_gen). - "billing_tier": "free", - "is_premium": False, - } - ) - - for cfg in global_configs: - billing_tier = str(cfg.get("billing_tier", "free")).lower() - safe_configs.append( - { - "id": cfg.get("id"), - "name": cfg.get("name"), - "description": cfg.get("description"), - "provider": cfg.get("provider") or cfg.get("litellm_provider"), - "custom_provider": cfg.get("custom_provider"), - "model_name": cfg.get("model_name"), - "api_base": cfg.get("api_base") or None, - "api_version": cfg.get("api_version") or None, - "litellm_params": cfg.get("litellm_params", {}), - "is_global": True, - "billing_tier": billing_tier, - # Mirror chat (``new_llm_config_routes``) so the new-chat - # selector's premium badge logic keys off the same - # field across chat / image / vision tabs. - "is_premium": billing_tier == "premium", - "quota_reserve_micros": cfg.get("quota_reserve_micros"), - } - ) - - return safe_configs - except Exception as e: - logger.exception("Failed to fetch global image generation configs") - raise HTTPException( - status_code=500, detail=f"Failed to fetch configs: {e!s}" - ) from e - - -# ============================================================================= -# ImageGenerationConfig CRUD -# ============================================================================= - - -@router.post("/image-generation-configs", response_model=ImageGenerationConfigRead) -async def create_image_gen_config( - config_data: ImageGenerationConfigCreate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """Create a new image generation config for a search space.""" - try: - await check_permission( - session, - user, - config_data.search_space_id, - Permission.IMAGE_GENERATIONS_CREATE.value, - "You don't have permission to create image generation configs in this search space", - ) - - db_config = ImageGenerationConfig(**config_data.model_dump(), user_id=user.id) - session.add(db_config) - await session.commit() - await session.refresh(db_config) - return db_config - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to create ImageGenerationConfig") - raise HTTPException( - status_code=500, detail=f"Failed to create config: {e!s}" - ) from e - - -@router.get("/image-generation-configs", response_model=list[ImageGenerationConfigRead]) -async def list_image_gen_configs( - search_space_id: int, - skip: int = 0, - limit: int = 100, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """List image generation configs for a search space.""" - try: - await check_permission( - session, - user, - search_space_id, - Permission.IMAGE_GENERATIONS_READ.value, - "You don't have permission to view image generation configs in this search space", - ) - - result = await session.execute( - select(ImageGenerationConfig) - .filter(ImageGenerationConfig.search_space_id == search_space_id) - .order_by(ImageGenerationConfig.created_at.desc()) - .offset(skip) - .limit(limit) - ) - return result.scalars().all() - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to list ImageGenerationConfigs") - raise HTTPException( - status_code=500, detail=f"Failed to fetch configs: {e!s}" - ) from e - - -@router.get( - "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead -) -async def get_image_gen_config( - config_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """Get a specific image generation config by ID.""" - try: - result = await session.execute( - select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) - ) - db_config = result.scalars().first() - if not db_config: - raise HTTPException(status_code=404, detail="Config not found") - - await check_permission( - session, - user, - db_config.search_space_id, - Permission.IMAGE_GENERATIONS_READ.value, - "You don't have permission to view image generation configs in this search space", - ) - return db_config - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to get ImageGenerationConfig") - raise HTTPException( - status_code=500, detail=f"Failed to fetch config: {e!s}" - ) from e - - -@router.put( - "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead -) -async def update_image_gen_config( - config_id: int, - update_data: ImageGenerationConfigUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """Update an existing image generation config.""" - try: - result = await session.execute( - select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) - ) - db_config = result.scalars().first() - if not db_config: - raise HTTPException(status_code=404, detail="Config not found") - - await check_permission( - session, - user, - db_config.search_space_id, - Permission.IMAGE_GENERATIONS_CREATE.value, - "You don't have permission to update image generation configs in this search space", - ) - - for key, value in update_data.model_dump(exclude_unset=True).items(): - setattr(db_config, key, value) - - await session.commit() - await session.refresh(db_config) - return db_config - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to update ImageGenerationConfig") - raise HTTPException( - status_code=500, detail=f"Failed to update config: {e!s}" - ) from e - - -@router.delete("/image-generation-configs/{config_id}", response_model=dict) -async def delete_image_gen_config( - config_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """Delete an image generation config.""" - try: - result = await session.execute( - select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) - ) - db_config = result.scalars().first() - if not db_config: - raise HTTPException(status_code=404, detail="Config not found") - - await check_permission( - session, - user, - db_config.search_space_id, - Permission.IMAGE_GENERATIONS_DELETE.value, - "You don't have permission to delete image generation configs in this search space", - ) - - await session.delete(db_config) - await session.commit() - return { - "message": "Image generation config deleted successfully", - "id": config_id, - } - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to delete ImageGenerationConfig") - raise HTTPException( - status_code=500, detail=f"Failed to delete config: {e!s}" - ) from e - - # ============================================================================= # Image Generation Execution + Results CRUD # ============================================================================= @@ -536,7 +269,7 @@ async def create_image_generation( raise HTTPException(status_code=404, detail="Search space not found") billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen( - session, data.image_generation_config_id, search_space + session, data.image_gen_model_id, search_space ) # billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError @@ -562,7 +295,7 @@ async def create_image_generation( size=data.size, style=data.style, response_format=data.response_format, - image_generation_config_id=data.image_generation_config_id, + image_gen_model_id=data.image_gen_model_id, search_space_id=data.search_space_id, created_by_id=user.id, ) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 1fd2e1e8e..90d246c54 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -11,6 +11,7 @@ from app.db import ( ConnectionScope, Model, ModelSource, + NewChatThread, Permission, SearchSpace, User, @@ -708,12 +709,26 @@ async def update_model_roles( search_space = await _get_search_space(session, search_space_id) updates = data.model_dump(exclude_unset=True) if "chat_model_id" in updates: - search_space.chat_model_id = await _validate_role_model_id( + previous_chat_model_id = search_space.chat_model_id + next_chat_model_id = await _validate_role_model_id( session, search_space_id=search_space_id, model_id=updates["chat_model_id"], capability="chat", ) + search_space.chat_model_id = next_chat_model_id + if next_chat_model_id != previous_chat_model_id: + await session.execute( + update(NewChatThread) + .where(NewChatThread.search_space_id == search_space_id) + .values(pinned_llm_config_id=None) + ) + logger.info( + "Cleared auto model pins for search_space_id=%s after chat_model_id change (%s -> %s)", + search_space_id, + previous_chat_model_id, + next_chat_model_id, + ) if "vision_model_id" in updates: search_space.vision_model_id = await _validate_role_model_id( session, diff --git a/surfsense_backend/app/routes/new_llm_config_routes.py b/surfsense_backend/app/routes/new_llm_config_routes.py deleted file mode 100644 index adba5b5ae..000000000 --- a/surfsense_backend/app/routes/new_llm_config_routes.py +++ /dev/null @@ -1,480 +0,0 @@ -""" -API routes for NewLLMConfig CRUD operations. - -NewLLMConfig combines model settings with prompt configuration: -- LLM provider, model, API key, etc. -- Configurable system instructions -- Citation toggle -""" - -import logging - -from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select - -from app.config import config -from app.db import ( - NewLLMConfig, - Permission, - User, - get_async_session, -) -from app.prompts.default_system_instructions import get_default_system_instructions -from app.schemas import ( - DefaultSystemInstructionsResponse, - GlobalNewLLMConfigRead, - NewLLMConfigCreate, - NewLLMConfigRead, - NewLLMConfigUpdate, -) -from app.services.llm_service import validate_llm_config -from app.services.provider_capabilities import derive_supports_image_input -from app.users import current_active_user -from app.utils.rbac import check_permission - -router = APIRouter() -logger = logging.getLogger(__name__) - - -def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead: - """Augment a BYOK chat config row with the derived ``supports_image_input``. - - There is no DB column for ``supports_image_input`` — the value is - resolved at the API boundary from LiteLLM's authoritative model map - (default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps - the response shape consistent across list / detail / create / update - endpoints without having to remember to set the field at every call - site. - """ - provider_value = ( - config.provider.value - if hasattr(config.provider, "value") - else str(config.provider) - ) - litellm_params = config.litellm_params or {} - base_model = ( - litellm_params.get("base_model") if isinstance(litellm_params, dict) else None - ) - supports_image_input = derive_supports_image_input( - provider=provider_value.lower(), - model_name=config.model_name, - base_model=base_model, - custom_provider=config.custom_provider, - ) - # ``model_validate`` runs the Pydantic conversion using the ORM - # attribute access path enabled by ``ConfigDict(from_attributes=True)``, - # then we layer the derived field on. ``model_copy(update=...)`` keeps - # the surface immutable from the caller's perspective. - base_read = NewLLMConfigRead.model_validate(config) - return base_read.model_copy(update={"supports_image_input": supports_image_input}) - - -# ============================================================================= -# Global Configs Routes -# ============================================================================= - - -@router.get("/global-new-llm-configs", response_model=list[GlobalNewLLMConfigRead]) -async def get_global_new_llm_configs( - user: User = Depends(current_active_user), -): - """ - Get all available global NewLLMConfig configurations. - These are pre-configured by the system administrator and available to all users. - API keys are not exposed through this endpoint. - - Includes: - - Auto mode (ID 0): Uses LiteLLM Router for automatic load balancing - - Global configs (negative IDs): Individual pre-configured LLM providers - """ - try: - global_configs = config.GLOBAL_LLM_CONFIGS - safe_configs = [] - - # Only include Auto mode if there are actual global configs to route to - # Auto mode requires at least one global config with valid API key - if global_configs and len(global_configs) > 0: - safe_configs.append( - { - "id": 0, - "name": "Auto (Fastest)", - "description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling. Recommended for most users.", - "provider": "AUTO", - "custom_provider": None, - "model_name": "auto", - "api_base": None, - "litellm_params": {}, - "system_instructions": "", - "use_default_system_instructions": True, - "citations_enabled": True, - "is_global": True, - "is_auto_mode": True, - "billing_tier": "free", - "is_premium": False, - "anonymous_enabled": False, - "seo_enabled": False, - "seo_slug": None, - "seo_title": None, - "seo_description": None, - "quota_reserve_tokens": None, - # Auto routes across the configured pool, which usually - # includes at least one vision-capable deployment, so - # treat Auto as image-capable. The router itself will - # still pick a vision-capable deployment for messages - # carrying image_url blocks (LiteLLM Router falls back - # on ``404`` per its ``allowed_fails`` policy). - "supports_image_input": True, - } - ) - - # Add individual global configs - for cfg in global_configs: - # Capability resolution: explicit value (YAML override or OR - # `_supports_image_input(model)` payload baked in by the - # OpenRouter integration service) wins. Fall back to the - # LiteLLM-driven helper which default-allows on unknown so - # we don't hide vision-capable models that happen to lack a - # YAML annotation. The streaming task safety net is the - # only place a False ever blocks. - if "supports_image_input" in cfg: - supports_image_input = bool(cfg.get("supports_image_input")) - else: - cfg_litellm_params = cfg.get("litellm_params") or {} - cfg_base_model = ( - cfg_litellm_params.get("base_model") - if isinstance(cfg_litellm_params, dict) - else None - ) - supports_image_input = derive_supports_image_input( - provider=cfg.get("provider") or cfg.get("litellm_provider"), - model_name=cfg.get("model_name"), - base_model=cfg_base_model, - custom_provider=cfg.get("custom_provider"), - ) - - safe_config = { - "id": cfg.get("id"), - "name": cfg.get("name"), - "description": cfg.get("description"), - "provider": cfg.get("provider") or cfg.get("litellm_provider"), - "custom_provider": cfg.get("custom_provider"), - "model_name": cfg.get("model_name"), - "api_base": cfg.get("api_base") or None, - "litellm_params": cfg.get("litellm_params", {}), - # New prompt configuration fields - "system_instructions": cfg.get("system_instructions", ""), - "use_default_system_instructions": cfg.get( - "use_default_system_instructions", True - ), - "citations_enabled": cfg.get("citations_enabled", True), - "is_global": True, - "billing_tier": cfg.get("billing_tier", "free"), - "is_premium": cfg.get("billing_tier", "free") == "premium", - "anonymous_enabled": cfg.get("anonymous_enabled", False), - "seo_enabled": cfg.get("seo_enabled", False), - "seo_slug": cfg.get("seo_slug"), - "seo_title": cfg.get("seo_title"), - "seo_description": cfg.get("seo_description"), - "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), - "supports_image_input": supports_image_input, - } - safe_configs.append(safe_config) - - return safe_configs - except Exception as e: - logger.exception("Failed to fetch global NewLLMConfigs") - raise HTTPException( - status_code=500, detail=f"Failed to fetch global configurations: {e!s}" - ) from e - - -# ============================================================================= -# CRUD Routes -# ============================================================================= - - -@router.post("/new-llm-configs", response_model=NewLLMConfigRead) -async def create_new_llm_config( - config_data: NewLLMConfigCreate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Create a new NewLLMConfig for a search space. - Requires LLM_CONFIGS_CREATE permission. - """ - try: - # Verify user has permission - await check_permission( - session, - user, - config_data.search_space_id, - Permission.LLM_CONFIGS_CREATE.value, - "You don't have permission to create LLM configurations in this search space", - ) - - # Validate the LLM configuration by making a test API call - is_valid, error_message = await validate_llm_config( - provider=config_data.provider.value, - model_name=config_data.model_name, - api_key=config_data.api_key, - api_base=config_data.api_base, - custom_provider=config_data.custom_provider, - litellm_params=config_data.litellm_params, - ) - - if not is_valid: - raise HTTPException( - status_code=400, - detail=f"Invalid LLM configuration: {error_message}", - ) - - # Create the config with user association - db_config = NewLLMConfig(**config_data.model_dump(), user_id=user.id) - session.add(db_config) - await session.commit() - await session.refresh(db_config) - - return _serialize_byok_config(db_config) - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to create NewLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to create configuration: {e!s}" - ) from e - - -@router.get("/new-llm-configs", response_model=list[NewLLMConfigRead]) -async def list_new_llm_configs( - search_space_id: int, - skip: int = 0, - limit: int = 100, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Get all NewLLMConfigs for a search space. - Requires LLM_CONFIGS_READ permission. - """ - try: - # Verify user has permission - await check_permission( - session, - user, - search_space_id, - Permission.LLM_CONFIGS_READ.value, - "You don't have permission to view LLM configurations in this search space", - ) - - result = await session.execute( - select(NewLLMConfig) - .filter(NewLLMConfig.search_space_id == search_space_id) - .order_by(NewLLMConfig.created_at.desc()) - .offset(skip) - .limit(limit) - ) - - return [_serialize_byok_config(cfg) for cfg in result.scalars().all()] - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to list NewLLMConfigs") - raise HTTPException( - status_code=500, detail=f"Failed to fetch configurations: {e!s}" - ) from e - - -@router.get( - "/new-llm-configs/default-system-instructions", - response_model=DefaultSystemInstructionsResponse, -) -async def get_default_system_instructions_endpoint( - user: User = Depends(current_active_user), -): - """ - Get the default SURFSENSE_SYSTEM_INSTRUCTIONS template. - Useful for pre-populating the UI when creating a new configuration. - """ - return DefaultSystemInstructionsResponse( - default_system_instructions=get_default_system_instructions() - ) - - -@router.get("/new-llm-configs/{config_id}", response_model=NewLLMConfigRead) -async def get_new_llm_config( - config_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Get a specific NewLLMConfig by ID. - Requires LLM_CONFIGS_READ permission. - """ - try: - result = await session.execute( - select(NewLLMConfig).filter(NewLLMConfig.id == config_id) - ) - config = result.scalars().first() - - if not config: - raise HTTPException(status_code=404, detail="Configuration not found") - - # Verify user has permission - await check_permission( - session, - user, - config.search_space_id, - Permission.LLM_CONFIGS_READ.value, - "You don't have permission to view LLM configurations in this search space", - ) - - return _serialize_byok_config(config) - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to get NewLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to fetch configuration: {e!s}" - ) from e - - -@router.put("/new-llm-configs/{config_id}", response_model=NewLLMConfigRead) -async def update_new_llm_config( - config_id: int, - update_data: NewLLMConfigUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Update an existing NewLLMConfig. - Requires LLM_CONFIGS_UPDATE permission. - """ - try: - result = await session.execute( - select(NewLLMConfig).filter(NewLLMConfig.id == config_id) - ) - config = result.scalars().first() - - if not config: - raise HTTPException(status_code=404, detail="Configuration not found") - - # Verify user has permission - await check_permission( - session, - user, - config.search_space_id, - Permission.LLM_CONFIGS_UPDATE.value, - "You don't have permission to update LLM configurations in this search space", - ) - - update_dict = update_data.model_dump(exclude_unset=True) - - # If updating LLM settings, validate them - if any( - key in update_dict - for key in [ - "provider", - "model_name", - "api_key", - "api_base", - "custom_provider", - "litellm_params", - ] - ): - # Build the validation config from existing + updates - validation_config = { - "provider": update_dict.get("provider", config.provider).value - if hasattr(update_dict.get("provider", config.provider), "value") - else update_dict.get("provider", config.provider.value), - "model_name": update_dict.get("model_name", config.model_name), - "api_key": update_dict.get("api_key", config.api_key), - "api_base": update_dict.get("api_base", config.api_base), - "custom_provider": update_dict.get( - "custom_provider", config.custom_provider - ), - "litellm_params": update_dict.get( - "litellm_params", config.litellm_params - ), - } - - is_valid, error_message = await validate_llm_config( - provider=validation_config["provider"], - model_name=validation_config["model_name"], - api_key=validation_config["api_key"], - api_base=validation_config["api_base"], - custom_provider=validation_config["custom_provider"], - litellm_params=validation_config["litellm_params"], - ) - - if not is_valid: - raise HTTPException( - status_code=400, - detail=f"Invalid LLM configuration: {error_message}", - ) - - # Apply updates - for key, value in update_dict.items(): - setattr(config, key, value) - - await session.commit() - await session.refresh(config) - - return _serialize_byok_config(config) - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to update NewLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to update configuration: {e!s}" - ) from e - - -@router.delete("/new-llm-configs/{config_id}", response_model=dict) -async def delete_new_llm_config( - config_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Delete a NewLLMConfig. - Requires LLM_CONFIGS_DELETE permission. - """ - try: - result = await session.execute( - select(NewLLMConfig).filter(NewLLMConfig.id == config_id) - ) - config = result.scalars().first() - - if not config: - raise HTTPException(status_code=404, detail="Configuration not found") - - # Verify user has permission - await check_permission( - session, - user, - config.search_space_id, - Permission.LLM_CONFIGS_DELETE.value, - "You don't have permission to delete LLM configurations in this search space", - ) - - await session.delete(config) - await session.commit() - - return {"message": "Configuration deleted successfully", "id": config_id} - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to delete NewLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to delete configuration: {e!s}" - ) from e diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 7c5fbf28b..592a9dd0e 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -1,27 +1,20 @@ import logging from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import func, update +from sqlalchemy import func from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.config import config from app.db import ( - ImageGenerationConfig, - NewChatThread, - NewLLMConfig, Permission, SearchSpace, SearchSpaceMembership, SearchSpaceRole, User, - VisionLLMConfig, get_async_session, get_default_roles_config, ) from app.schemas import ( - LLMPreferencesRead, - LLMPreferencesUpdate, SearchSpaceCreate, SearchSpaceRead, SearchSpaceUpdate, @@ -377,357 +370,6 @@ async def delete_search_space( ) from e -# ============================================================================= -# LLM Preferences Routes -# ============================================================================= - - -async def _get_llm_config_by_id( - session: AsyncSession, config_id: int | None -) -> dict | None: - """ - Get an LLM config by ID as a dictionary. Returns database config for positive IDs, - global config for negative IDs, Auto mode config for ID 0, or None if ID is None. - """ - if config_id is None: - return None - - # Auto mode (ID 0) - uses LiteLLM Router for load balancing - if config_id == 0: - return { - "id": 0, - "name": "Auto (Fastest)", - "description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling", - "provider": "AUTO", - "custom_provider": None, - "model_name": "auto", - "api_base": None, - "litellm_params": {}, - "system_instructions": "", - "use_default_system_instructions": True, - "citations_enabled": True, - "is_global": True, - "is_auto_mode": True, - } - - if config_id < 0: - # Global config - find from YAML - global_configs = config.GLOBAL_LLM_CONFIGS - for cfg in global_configs: - if cfg.get("id") == config_id: - return { - "id": cfg.get("id"), - "name": cfg.get("name"), - "description": cfg.get("description"), - "provider": cfg.get("provider") or cfg.get("litellm_provider"), - "custom_provider": cfg.get("custom_provider"), - "model_name": cfg.get("model_name"), - "api_base": cfg.get("api_base"), - "litellm_params": cfg.get("litellm_params", {}), - "system_instructions": cfg.get("system_instructions", ""), - "use_default_system_instructions": cfg.get( - "use_default_system_instructions", True - ), - "citations_enabled": cfg.get("citations_enabled", True), - "is_global": True, - } - return None - else: - # Database config - convert to dict - result = await session.execute( - select(NewLLMConfig).filter(NewLLMConfig.id == config_id) - ) - db_config = result.scalars().first() - if db_config: - return { - "id": db_config.id, - "name": db_config.name, - "description": db_config.description, - "provider": db_config.provider.value if db_config.provider else None, - "custom_provider": db_config.custom_provider, - "model_name": db_config.model_name, - "api_key": db_config.api_key, - "api_base": db_config.api_base, - "litellm_params": db_config.litellm_params or {}, - "system_instructions": db_config.system_instructions or "", - "use_default_system_instructions": db_config.use_default_system_instructions, - "citations_enabled": db_config.citations_enabled, - "created_at": db_config.created_at.isoformat() - if db_config.created_at - else None, - "search_space_id": db_config.search_space_id, - } - return None - - -async def _get_image_gen_config_by_id( - session: AsyncSession, config_id: int | None -) -> dict | None: - """ - Get an image generation config by ID as a dictionary. - Returns Auto mode for ID 0, global config for negative IDs, - DB ImageGenerationConfig for positive IDs, or None. - """ - if config_id is None: - return None - - if config_id == 0: - return { - "id": 0, - "name": "Auto (Fastest)", - "description": "Automatically routes requests across available image generation providers", - "provider": "AUTO", - "model_name": "auto", - "is_global": True, - "is_auto_mode": True, - "billing_tier": "free", - } - - if config_id < 0: - for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: - if cfg.get("id") == config_id: - return { - "id": cfg.get("id"), - "name": cfg.get("name"), - "description": cfg.get("description"), - "provider": cfg.get("provider") or cfg.get("litellm_provider"), - "custom_provider": cfg.get("custom_provider"), - "model_name": cfg.get("model_name"), - "api_base": cfg.get("api_base") or None, - "api_version": cfg.get("api_version") or None, - "litellm_params": cfg.get("litellm_params", {}), - "is_global": True, - "billing_tier": cfg.get("billing_tier", "free"), - } - return None - - # Positive ID: query ImageGenerationConfig table - result = await session.execute( - select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) - ) - db_config = result.scalars().first() - if db_config: - return { - "id": db_config.id, - "name": db_config.name, - "description": db_config.description, - "provider": db_config.provider.value if db_config.provider else None, - "custom_provider": db_config.custom_provider, - "model_name": db_config.model_name, - "api_base": db_config.api_base, - "api_version": db_config.api_version, - "litellm_params": db_config.litellm_params or {}, - "created_at": db_config.created_at.isoformat() - if db_config.created_at - else None, - "search_space_id": db_config.search_space_id, - } - return None - - -async def _get_vision_llm_config_by_id( - session: AsyncSession, config_id: int | None -) -> dict | None: - if config_id is None: - return None - - if config_id == 0: - return { - "id": 0, - "name": "Auto (Fastest)", - "description": "Automatically routes requests across available vision LLM providers", - "provider": "AUTO", - "model_name": "auto", - "is_global": True, - "is_auto_mode": True, - "billing_tier": "free", - } - - if config_id < 0: - for cfg in config.GLOBAL_VISION_LLM_CONFIGS: - if cfg.get("id") == config_id: - return { - "id": cfg.get("id"), - "name": cfg.get("name"), - "description": cfg.get("description"), - "provider": cfg.get("provider") or cfg.get("litellm_provider"), - "custom_provider": cfg.get("custom_provider"), - "model_name": cfg.get("model_name"), - "api_base": cfg.get("api_base") or None, - "api_version": cfg.get("api_version") or None, - "litellm_params": cfg.get("litellm_params", {}), - "is_global": True, - "billing_tier": cfg.get("billing_tier", "free"), - } - return None - - result = await session.execute( - select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) - ) - db_config = result.scalars().first() - if db_config: - return { - "id": db_config.id, - "name": db_config.name, - "description": db_config.description, - "provider": db_config.provider.value if db_config.provider else None, - "custom_provider": db_config.custom_provider, - "model_name": db_config.model_name, - "api_base": db_config.api_base, - "api_version": db_config.api_version, - "litellm_params": db_config.litellm_params or {}, - "created_at": db_config.created_at.isoformat() - if db_config.created_at - else None, - "search_space_id": db_config.search_space_id, - } - return None - - -@router.get( - "/search-spaces/{search_space_id}/llm-preferences", - response_model=LLMPreferencesRead, -) -async def get_llm_preferences( - search_space_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Get LLM preferences (role assignments) for a search space. - Requires LLM_CONFIGS_READ permission. - """ - try: - # Check permission - await check_permission( - session, - user, - search_space_id, - Permission.LLM_CONFIGS_READ.value, - "You don't have permission to view LLM preferences", - ) - - result = await session.execute( - select(SearchSpace).filter(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - - if not search_space: - raise HTTPException(status_code=404, detail="Search space not found") - - # Get full config objects for each role - agent_llm = await _get_llm_config_by_id(session, search_space.agent_llm_id) - image_generation_config = await _get_image_gen_config_by_id( - session, search_space.image_generation_config_id - ) - vision_llm_config = await _get_vision_llm_config_by_id( - session, search_space.vision_llm_config_id - ) - - return LLMPreferencesRead( - agent_llm_id=search_space.agent_llm_id, - image_generation_config_id=search_space.image_generation_config_id, - vision_llm_config_id=search_space.vision_llm_config_id, - agent_llm=agent_llm, - image_generation_config=image_generation_config, - vision_llm_config=vision_llm_config, - ) - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to get LLM preferences") - raise HTTPException( - status_code=500, detail=f"Failed to get LLM preferences: {e!s}" - ) from e - - -@router.put( - "/search-spaces/{search_space_id}/llm-preferences", - response_model=LLMPreferencesRead, -) -async def update_llm_preferences( - search_space_id: int, - preferences: LLMPreferencesUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Update LLM preferences (role assignments) for a search space. - Requires LLM_CONFIGS_UPDATE permission. - """ - try: - # Check permission - await check_permission( - session, - user, - search_space_id, - Permission.LLM_CONFIGS_UPDATE.value, - "You don't have permission to update LLM preferences", - ) - - result = await session.execute( - select(SearchSpace).filter(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - - if not search_space: - raise HTTPException(status_code=404, detail="Search space not found") - - # Update preferences - update_data = preferences.model_dump(exclude_unset=True) - previous_agent_llm_id = search_space.agent_llm_id - for key, value in update_data.items(): - setattr(search_space, key, value) - - agent_llm_changed = ( - "agent_llm_id" in update_data - and update_data["agent_llm_id"] != previous_agent_llm_id - ) - if agent_llm_changed: - await session.execute( - update(NewChatThread) - .where(NewChatThread.search_space_id == search_space_id) - .values(pinned_llm_config_id=None) - ) - logger.info( - "Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)", - search_space_id, - previous_agent_llm_id, - update_data["agent_llm_id"], - ) - - await session.commit() - await session.refresh(search_space) - - # Get full config objects for response - agent_llm = await _get_llm_config_by_id(session, search_space.agent_llm_id) - image_generation_config = await _get_image_gen_config_by_id( - session, search_space.image_generation_config_id - ) - vision_llm_config = await _get_vision_llm_config_by_id( - session, search_space.vision_llm_config_id - ) - - return LLMPreferencesRead( - agent_llm_id=search_space.agent_llm_id, - image_generation_config_id=search_space.image_generation_config_id, - vision_llm_config_id=search_space.vision_llm_config_id, - agent_llm=agent_llm, - image_generation_config=image_generation_config, - vision_llm_config=vision_llm_config, - ) - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to update LLM preferences") - raise HTTPException( - status_code=500, detail=f"Failed to update LLM preferences: {e!s}" - ) from e - - @router.get("/searchspaces/{search_space_id}/snapshots") async def list_search_space_snapshots( search_space_id: int, diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py deleted file mode 100644 index b93d25b9c..000000000 --- a/surfsense_backend/app/routes/vision_llm_routes.py +++ /dev/null @@ -1,304 +0,0 @@ -import logging - -from fastapi import APIRouter, Depends, HTTPException -from pydantic import BaseModel -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from app.config import config -from app.db import ( - Permission, - User, - VisionLLMConfig, - get_async_session, -) -from app.schemas import ( - GlobalVisionLLMConfigRead, - VisionLLMConfigCreate, - VisionLLMConfigRead, - VisionLLMConfigUpdate, -) -from app.services.vision_model_list_service import get_vision_model_list -from app.users import current_active_user -from app.utils.rbac import check_permission - -router = APIRouter() -logger = logging.getLogger(__name__) - - -# ============================================================================= -# Vision Model Catalogue (from OpenRouter, filtered for image-input models) -# ============================================================================= - - -class VisionModelListItem(BaseModel): - value: str - label: str - provider: str - context_window: str | None = None - - -@router.get("/vision-models", response_model=list[VisionModelListItem]) -async def list_vision_models( - user: User = Depends(current_active_user), -): - """Return vision-capable models sourced from OpenRouter (filtered by image input).""" - try: - return await get_vision_model_list() - except Exception as e: - logger.exception("Failed to fetch vision model list") - raise HTTPException( - status_code=500, detail=f"Failed to fetch vision model list: {e!s}" - ) from e - - -# ============================================================================= -# Global Vision LLM Configs (from YAML) -# ============================================================================= - - -@router.get( - "/global-vision-llm-configs", - response_model=list[GlobalVisionLLMConfigRead], -) -async def get_global_vision_llm_configs( - user: User = Depends(current_active_user), -): - try: - global_configs = config.GLOBAL_VISION_LLM_CONFIGS - safe_configs = [] - - if global_configs and len(global_configs) > 0: - safe_configs.append( - { - "id": 0, - "name": "Auto (Fastest)", - "description": "Automatically routes across available vision LLM providers.", - "provider": "AUTO", - "custom_provider": None, - "model_name": "auto", - "api_base": None, - "api_version": None, - "litellm_params": {}, - "is_global": True, - "is_auto_mode": True, - # Auto mode treated as free until per-deployment billing-tier - # surfacing lands; see ``get_vision_llm`` for parity. - "billing_tier": "free", - "is_premium": False, - } - ) - - for cfg in global_configs: - billing_tier = str(cfg.get("billing_tier", "free")).lower() - safe_configs.append( - { - "id": cfg.get("id"), - "name": cfg.get("name"), - "description": cfg.get("description"), - "provider": cfg.get("provider") or cfg.get("litellm_provider"), - "custom_provider": cfg.get("custom_provider"), - "model_name": cfg.get("model_name"), - "api_base": cfg.get("api_base") or None, - "api_version": cfg.get("api_version") or None, - "litellm_params": cfg.get("litellm_params", {}), - "is_global": True, - "billing_tier": billing_tier, - # Mirror chat (``new_llm_config_routes``) so the new-chat - # selector's premium badge logic keys off the same - # field across chat / image / vision tabs. - "is_premium": billing_tier == "premium", - "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), - "input_cost_per_token": cfg.get("input_cost_per_token"), - "output_cost_per_token": cfg.get("output_cost_per_token"), - } - ) - - return safe_configs - except Exception as e: - logger.exception("Failed to fetch global vision LLM configs") - raise HTTPException( - status_code=500, detail=f"Failed to fetch configs: {e!s}" - ) from e - - -# ============================================================================= -# VisionLLMConfig CRUD -# ============================================================================= - - -@router.post("/vision-llm-configs", response_model=VisionLLMConfigRead) -async def create_vision_llm_config( - config_data: VisionLLMConfigCreate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - try: - await check_permission( - session, - user, - config_data.search_space_id, - Permission.VISION_CONFIGS_CREATE.value, - "You don't have permission to create vision LLM configs in this search space", - ) - - db_config = VisionLLMConfig(**config_data.model_dump(), user_id=user.id) - session.add(db_config) - await session.commit() - await session.refresh(db_config) - return db_config - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to create VisionLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to create config: {e!s}" - ) from e - - -@router.get("/vision-llm-configs", response_model=list[VisionLLMConfigRead]) -async def list_vision_llm_configs( - search_space_id: int, - skip: int = 0, - limit: int = 100, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - try: - await check_permission( - session, - user, - search_space_id, - Permission.VISION_CONFIGS_READ.value, - "You don't have permission to view vision LLM configs in this search space", - ) - - result = await session.execute( - select(VisionLLMConfig) - .filter(VisionLLMConfig.search_space_id == search_space_id) - .order_by(VisionLLMConfig.created_at.desc()) - .offset(skip) - .limit(limit) - ) - return result.scalars().all() - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to list VisionLLMConfigs") - raise HTTPException( - status_code=500, detail=f"Failed to fetch configs: {e!s}" - ) from e - - -@router.get("/vision-llm-configs/{config_id}", response_model=VisionLLMConfigRead) -async def get_vision_llm_config( - config_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - try: - result = await session.execute( - select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) - ) - db_config = result.scalars().first() - if not db_config: - raise HTTPException(status_code=404, detail="Config not found") - - await check_permission( - session, - user, - db_config.search_space_id, - Permission.VISION_CONFIGS_READ.value, - "You don't have permission to view vision LLM configs in this search space", - ) - return db_config - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to get VisionLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to fetch config: {e!s}" - ) from e - - -@router.put("/vision-llm-configs/{config_id}", response_model=VisionLLMConfigRead) -async def update_vision_llm_config( - config_id: int, - update_data: VisionLLMConfigUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - try: - result = await session.execute( - select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) - ) - db_config = result.scalars().first() - if not db_config: - raise HTTPException(status_code=404, detail="Config not found") - - await check_permission( - session, - user, - db_config.search_space_id, - Permission.VISION_CONFIGS_CREATE.value, - "You don't have permission to update vision LLM configs in this search space", - ) - - for key, value in update_data.model_dump(exclude_unset=True).items(): - setattr(db_config, key, value) - - await session.commit() - await session.refresh(db_config) - return db_config - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to update VisionLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to update config: {e!s}" - ) from e - - -@router.delete("/vision-llm-configs/{config_id}", response_model=dict) -async def delete_vision_llm_config( - config_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - try: - result = await session.execute( - select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) - ) - db_config = result.scalars().first() - if not db_config: - raise HTTPException(status_code=404, detail="Config not found") - - await check_permission( - session, - user, - db_config.search_space_id, - Permission.VISION_CONFIGS_DELETE.value, - "You don't have permission to delete vision LLM configs in this search space", - ) - - await session.delete(db_config) - await session.commit() - return { - "message": "Vision LLM config deleted successfully", - "id": config_id, - } - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to delete VisionLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to delete config: {e!s}" - ) from e diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 3c4fdfa83..f577397b6 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -34,11 +34,6 @@ from .folders import ( ) from .google_drive import DriveItem, GoogleDriveIndexingOptions, GoogleDriveIndexRequest from .image_generation import ( - GlobalImageGenConfigRead, - ImageGenerationConfigCreate, - ImageGenerationConfigPublic, - ImageGenerationConfigRead, - ImageGenerationConfigUpdate, ImageGenerationCreate, ImageGenerationListRead, ImageGenerationRead, @@ -74,16 +69,6 @@ from .new_chat import ( ThreadListItem, ThreadListResponse, ) -from .new_llm_config import ( - DefaultSystemInstructionsResponse, - GlobalNewLLMConfigRead, - LLMPreferencesRead, - LLMPreferencesUpdate, - NewLLMConfigCreate, - NewLLMConfigPublic, - NewLLMConfigRead, - NewLLMConfigUpdate, -) from .rbac_schemas import ( InviteAcceptRequest, InviteAcceptResponse, @@ -142,14 +127,6 @@ from .video_presentations import ( VideoPresentationRead, VideoPresentationUpdate, ) -from .vision_llm import ( - GlobalVisionLLMConfigRead, - VisionLLMConfigCreate, - VisionLLMConfigPublic, - VisionLLMConfigRead, - VisionLLMConfigUpdate, -) - __all__ = [ # Folder schemas "BulkDocumentMove", @@ -169,7 +146,6 @@ __all__ = [ "CreditPurchaseHistoryResponse", "CreditPurchaseRead", "CreditStripeStatusResponse", - "DefaultSystemInstructionsResponse", # Document schemas "DocumentBase", "DocumentMove", @@ -192,19 +168,10 @@ __all__ = [ "FolderRead", "FolderReorder", "FolderUpdate", - "GlobalImageGenConfigRead", - "GlobalNewLLMConfigRead", - # Vision LLM Config schemas - "GlobalVisionLLMConfigRead", "GoogleDriveIndexRequest", "GoogleDriveIndexingOptions", # Base schemas "IDModel", - # Image Generation Config schemas - "ImageGenerationConfigCreate", - "ImageGenerationConfigPublic", - "ImageGenerationConfigRead", - "ImageGenerationConfigUpdate", # Image Generation schemas "ImageGenerationCreate", "ImageGenerationListRead", @@ -216,9 +183,6 @@ __all__ = [ "InviteInfoResponse", "InviteRead", "InviteUpdate", - # LLM Preferences schemas - "LLMPreferencesRead", - "LLMPreferencesUpdate", # Log schemas "LogBase", "LogCreate", @@ -255,11 +219,6 @@ __all__ = [ "NewChatThreadRead", "NewChatThreadUpdate", "NewChatThreadWithMessages", - # NewLLMConfig schemas - "NewLLMConfigCreate", - "NewLLMConfigPublic", - "NewLLMConfigRead", - "NewLLMConfigUpdate", "PagePurchaseHistoryResponse", "PagePurchaseRead", "PaginatedResponse", @@ -303,8 +262,4 @@ __all__ = [ "VideoPresentationCreate", "VideoPresentationRead", "VideoPresentationUpdate", - "VisionLLMConfigCreate", - "VisionLLMConfigPublic", - "VisionLLMConfigRead", - "VisionLLMConfigUpdate", ] diff --git a/surfsense_backend/app/schemas/image_generation.py b/surfsense_backend/app/schemas/image_generation.py index 4262b2b3f..83671cc77 100644 --- a/surfsense_backend/app/schemas/image_generation.py +++ b/surfsense_backend/app/schemas/image_generation.py @@ -1,109 +1,10 @@ -""" -Pydantic schemas for Image Generation configs and generation requests. +"""Pydantic schemas for image generation requests/results.""" -ImageGenerationConfig: CRUD schemas for user-created image gen model configs. -ImageGeneration: Schemas for the actual image generation requests/results. -GlobalImageGenConfigRead: Schema for admin-configured YAML configs. -""" - -import uuid from datetime import datetime from typing import Any from pydantic import BaseModel, ConfigDict, Field -from app.db import ImageGenProvider - -# ============================================================================= -# ImageGenerationConfig CRUD Schemas -# ============================================================================= - - -class ImageGenerationConfigBase(BaseModel): - """Base schema with fields for ImageGenerationConfig.""" - - name: str = Field( - ..., max_length=100, description="User-friendly name for the config" - ) - description: str | None = Field( - None, max_length=500, description="Optional description" - ) - provider: ImageGenProvider = Field( - ..., - description="Image generation provider (OpenAI, Azure, Google AI Studio, Vertex AI, Bedrock, Recraft, OpenRouter, Xinference, Nscale)", - ) - custom_provider: str | None = Field( - None, max_length=100, description="Custom provider name" - ) - model_name: str = Field( - ..., max_length=100, description="Model name (e.g., dall-e-3, gpt-image-1)" - ) - api_key: str = Field(..., description="API key for the provider") - api_base: str | None = Field( - None, max_length=500, description="Optional API base URL" - ) - api_version: str | None = Field( - None, - max_length=50, - description="Azure-specific API version (e.g., '2024-02-15-preview')", - ) - litellm_params: dict[str, Any] | None = Field( - default=None, description="Additional LiteLLM parameters" - ) - - -class ImageGenerationConfigCreate(ImageGenerationConfigBase): - """Schema for creating a new ImageGenerationConfig.""" - - search_space_id: int = Field( - ..., description="Search space ID to associate the config with" - ) - - -class ImageGenerationConfigUpdate(BaseModel): - """Schema for updating an existing ImageGenerationConfig. All fields optional.""" - - name: str | None = Field(None, max_length=100) - description: str | None = Field(None, max_length=500) - provider: ImageGenProvider | None = None - custom_provider: str | None = Field(None, max_length=100) - model_name: str | None = Field(None, max_length=100) - api_key: str | None = None - api_base: str | None = Field(None, max_length=500) - api_version: str | None = Field(None, max_length=50) - litellm_params: dict[str, Any] | None = None - - -class ImageGenerationConfigRead(ImageGenerationConfigBase): - """Schema for reading an ImageGenerationConfig (includes id and timestamps).""" - - id: int - created_at: datetime - search_space_id: int - user_id: uuid.UUID - - model_config = ConfigDict(from_attributes=True) - - -class ImageGenerationConfigPublic(BaseModel): - """Public schema that hides the API key (for list views).""" - - id: int - name: str - description: str | None = None - provider: ImageGenProvider - custom_provider: str | None = None - model_name: str - api_base: str | None = None - api_version: str | None = None - litellm_params: dict[str, Any] | None = None - created_at: datetime - search_space_id: int - user_id: uuid.UUID - - model_config = ConfigDict(from_attributes=True) - - # ============================================================================= # ImageGeneration (request/result) Schemas # ============================================================================= @@ -136,12 +37,12 @@ class ImageGenerationCreate(BaseModel): search_space_id: int = Field( ..., description="Search space ID to associate the generation with" ) - image_generation_config_id: int | None = Field( + image_gen_model_id: int | None = Field( None, description=( - "Image generation config ID. " - "0 = Auto mode (router), negative = global YAML config, positive = DB config. " - "If not provided, uses the search space's image_generation_config_id preference." + "Image generation model ID. " + "0 = Auto mode, negative = GLOBAL model, positive = BYOK Model row. " + "If not provided, uses the search space's image_gen_model_id preference." ), ) @@ -157,7 +58,7 @@ class ImageGenerationRead(BaseModel): size: str | None = None style: str | None = None response_format: str | None = None - image_generation_config_id: int | None = None + image_gen_model_id: int | None = None response_data: dict[str, Any] | None = None error_message: str | None = None search_space_id: int @@ -204,57 +105,3 @@ class ImageGenerationListRead(BaseModel): image_count=image_count, ) - -# ============================================================================= -# Global Image Gen Config (from YAML) -# ============================================================================= - - -class GlobalImageGenConfigRead(BaseModel): - """ - Schema for reading global image generation configs from YAML. - Global configs have negative IDs. API key is hidden. - ID 0 is reserved for Auto mode (LiteLLM Router load balancing). - - The ``billing_tier`` field allows the frontend to show a Premium/Free - badge and (more importantly) tells the backend whether to debit the - user's premium credit pool when this config is used. ``"free"`` is - the default for backward compatibility — admins must explicitly opt - a global config into ``"premium"``. - """ - - id: int = Field( - ..., - description="Config ID: 0 for Auto mode, negative for global configs", - ) - name: str - description: str | None = None - provider: str - custom_provider: str | None = None - model_name: str - api_base: str | None = None - api_version: str | None = None - litellm_params: dict[str, Any] | None = None - is_global: bool = True - is_auto_mode: bool = False - billing_tier: str = Field( - default="free", - description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", - ) - is_premium: bool = Field( - default=False, - description=( - "Convenience boolean derived server-side from " - "``billing_tier == 'premium'``. The new-chat model selector " - "keys its Free/Premium badge off this field for parity with " - "chat (`GlobalLLMConfigRead.is_premium`)." - ), - ) - quota_reserve_micros: int | None = Field( - default=None, - description=( - "Optional override for the reservation amount (in micro-USD) used when " - "this image generation is premium. Falls back to " - "QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted." - ), - ) diff --git a/surfsense_backend/app/schemas/new_llm_config.py b/surfsense_backend/app/schemas/new_llm_config.py deleted file mode 100644 index 2f04a9e66..000000000 --- a/surfsense_backend/app/schemas/new_llm_config.py +++ /dev/null @@ -1,256 +0,0 @@ -""" -Pydantic schemas for the NewLLMConfig API. - -NewLLMConfig combines model settings with prompt configuration: -- LLM provider, model, API key, etc. -- Configurable system instructions -- Citation toggle -""" - -import uuid -from datetime import datetime -from typing import Any - -from pydantic import BaseModel, ConfigDict, Field - -from app.db import LiteLLMProvider - - -class NewLLMConfigBase(BaseModel): - """Base schema with common fields for NewLLMConfig.""" - - name: str = Field( - ..., max_length=100, description="User-friendly name for the configuration" - ) - description: str | None = Field( - None, max_length=500, description="Optional description" - ) - - # Model Configuration - provider: LiteLLMProvider = Field(..., description="LiteLLM provider type") - custom_provider: str | None = Field( - None, max_length=100, description="Custom provider name when provider is CUSTOM" - ) - model_name: str = Field( - ..., max_length=100, description="Model name without provider prefix" - ) - api_key: str = Field(..., description="API key for the provider") - api_base: str | None = Field( - None, max_length=500, description="Optional API base URL" - ) - litellm_params: dict[str, Any] | None = Field( - default=None, description="Additional LiteLLM parameters" - ) - - # Prompt Configuration - system_instructions: str = Field( - default="", - description="Custom system instructions. Empty string uses default SURFSENSE_SYSTEM_INSTRUCTIONS.", - ) - use_default_system_instructions: bool = Field( - default=True, - description="Whether to use default instructions when system_instructions is empty", - ) - citations_enabled: bool = Field( - default=True, - description="Whether to include citation instructions in the system prompt", - ) - - -class NewLLMConfigCreate(NewLLMConfigBase): - """Schema for creating a new NewLLMConfig.""" - - search_space_id: int = Field( - ..., description="Search space ID to associate the config with" - ) - - -class NewLLMConfigUpdate(BaseModel): - """Schema for updating an existing NewLLMConfig. All fields are optional.""" - - name: str | None = Field(None, max_length=100) - description: str | None = Field(None, max_length=500) - - # Model Configuration - provider: LiteLLMProvider | None = None - custom_provider: str | None = Field(None, max_length=100) - model_name: str | None = Field(None, max_length=100) - api_key: str | None = None - api_base: str | None = Field(None, max_length=500) - litellm_params: dict[str, Any] | None = None - - # Prompt Configuration - system_instructions: str | None = None - use_default_system_instructions: bool | None = None - citations_enabled: bool | None = None - - -class NewLLMConfigRead(NewLLMConfigBase): - """Schema for reading a NewLLMConfig (includes id and timestamps).""" - - id: int - created_at: datetime - search_space_id: int - user_id: uuid.UUID - # Capability flag derived at the API boundary (no DB column). Default - # True matches the conservative-allow stance — a BYOK row that the - # route forgot to augment is not pre-judged. The streaming-task - # safety net is the only place a False actually blocks a request. - supports_image_input: bool = Field( - default=True, - description=( - "Whether the BYOK chat config can accept image inputs. Derived " - "at the route boundary from LiteLLM's authoritative model map " - "(``litellm.supports_vision``) — there is no DB column. " - "Default True is the conservative-allow stance for unknown / " - "unmapped models." - ), - ) - - model_config = ConfigDict(from_attributes=True) - - -class NewLLMConfigPublic(BaseModel): - """ - Public schema for NewLLMConfig that hides the API key. - Used when returning configs in list views or to users who shouldn't see keys. - """ - - id: int - name: str - description: str | None = None - - # Model Configuration (no api_key) - provider: LiteLLMProvider - custom_provider: str | None = None - model_name: str - api_base: str | None = None - litellm_params: dict[str, Any] | None = None - - # Prompt Configuration - system_instructions: str - use_default_system_instructions: bool - citations_enabled: bool - - created_at: datetime - search_space_id: int - user_id: uuid.UUID - # Capability flag derived at the API boundary (see NewLLMConfigRead). - supports_image_input: bool = Field( - default=True, - description=( - "Whether the BYOK chat config can accept image inputs. Derived " - "at the route boundary from LiteLLM's authoritative model map. " - "Default True is the conservative-allow stance." - ), - ) - - model_config = ConfigDict(from_attributes=True) - - -class DefaultSystemInstructionsResponse(BaseModel): - """Response schema for getting default system instructions.""" - - default_system_instructions: str = Field( - ..., description="The default SURFSENSE_SYSTEM_INSTRUCTIONS template" - ) - - -class GlobalNewLLMConfigRead(BaseModel): - """ - Schema for reading global LLM configs from YAML. - Global configs have negative IDs and no search_space_id. - API key is hidden for security. - - ID 0 is reserved for Auto mode which uses LiteLLM Router for load balancing. - """ - - id: int = Field( - ..., - description="Config ID: 0 for Auto mode, negative for global configs", - ) - name: str - description: str | None = None - - # Model Configuration (no api_key) - provider: str # String because YAML doesn't enforce enum, "AUTO" for Auto mode - custom_provider: str | None = None - model_name: str - api_base: str | None = None - litellm_params: dict[str, Any] | None = None - - # Prompt Configuration - system_instructions: str = "" - use_default_system_instructions: bool = True - citations_enabled: bool = True - - is_global: bool = True # Always true for global configs - is_auto_mode: bool = False # True only for Auto mode (ID 0) - - billing_tier: str = "free" - is_premium: bool = False - anonymous_enabled: bool = False - seo_enabled: bool = False - seo_slug: str | None = None - seo_title: str | None = None - seo_description: str | None = None - quota_reserve_tokens: int | None = None - supports_image_input: bool = Field( - default=True, - description=( - "Whether the model accepts image inputs (multimodal vision). " - "Derived server-side: OpenRouter dynamic configs use " - "``architecture.input_modalities``; YAML / BYOK use LiteLLM's " - "authoritative model map (``litellm.supports_vision``). The " - "new-chat selector hints with a 'No image' badge when this is " - "False and there are pending image attachments. The streaming " - "task fails fast only when LiteLLM *explicitly* marks a model " - "as text-only — unknown / unmapped models default-allow." - ), - ) - - -# ============================================================================= -# LLM Preferences Schemas (for role assignments) -# ============================================================================= - - -class LLMPreferencesRead(BaseModel): - """Schema for reading LLM preferences (role assignments) for a search space.""" - - agent_llm_id: int | None = Field( - None, description="ID of the LLM config to use for agent/chat tasks" - ) - image_generation_config_id: int | None = Field( - None, description="ID of the image generation config to use" - ) - vision_llm_config_id: int | None = Field( - None, - description="ID of the vision LLM config to use for vision/screenshot analysis", - ) - agent_llm: dict[str, Any] | None = Field( - None, description="Full config for chat model" - ) - image_generation_config: dict[str, Any] | None = Field( - None, description="Full config for image generation" - ) - vision_llm_config: dict[str, Any] | None = Field( - None, description="Full config for vision LLM" - ) - - model_config = ConfigDict(from_attributes=True) - - -class LLMPreferencesUpdate(BaseModel): - """Schema for updating LLM preferences.""" - - agent_llm_id: int | None = Field( - None, description="ID of the LLM config to use for agent/chat tasks" - ) - image_generation_config_id: int | None = Field( - None, description="ID of the image generation config to use" - ) - vision_llm_config_id: int | None = Field( - None, - description="ID of the vision LLM config to use for vision/screenshot analysis", - ) diff --git a/surfsense_backend/app/schemas/vision_llm.py b/surfsense_backend/app/schemas/vision_llm.py deleted file mode 100644 index d0eeaf5c6..000000000 --- a/surfsense_backend/app/schemas/vision_llm.py +++ /dev/null @@ -1,116 +0,0 @@ -import uuid -from datetime import datetime -from typing import Any - -from pydantic import BaseModel, ConfigDict, Field - -from app.db import VisionProvider - - -class VisionLLMConfigBase(BaseModel): - name: str = Field(..., max_length=100) - description: str | None = Field(None, max_length=500) - provider: VisionProvider = Field(...) - custom_provider: str | None = Field(None, max_length=100) - model_name: str = Field(..., max_length=100) - api_key: str = Field(...) - api_base: str | None = Field(None, max_length=500) - api_version: str | None = Field(None, max_length=50) - litellm_params: dict[str, Any] | None = Field(default=None) - - -class VisionLLMConfigCreate(VisionLLMConfigBase): - search_space_id: int = Field(...) - - -class VisionLLMConfigUpdate(BaseModel): - name: str | None = Field(None, max_length=100) - description: str | None = Field(None, max_length=500) - provider: VisionProvider | None = None - custom_provider: str | None = Field(None, max_length=100) - model_name: str | None = Field(None, max_length=100) - api_key: str | None = None - api_base: str | None = Field(None, max_length=500) - api_version: str | None = Field(None, max_length=50) - litellm_params: dict[str, Any] | None = None - - -class VisionLLMConfigRead(VisionLLMConfigBase): - id: int - created_at: datetime - search_space_id: int - user_id: uuid.UUID - - model_config = ConfigDict(from_attributes=True) - - -class VisionLLMConfigPublic(BaseModel): - id: int - name: str - description: str | None = None - provider: VisionProvider - custom_provider: str | None = None - model_name: str - api_base: str | None = None - api_version: str | None = None - litellm_params: dict[str, Any] | None = None - created_at: datetime - search_space_id: int - user_id: uuid.UUID - - model_config = ConfigDict(from_attributes=True) - - -class GlobalVisionLLMConfigRead(BaseModel): - """Schema for reading global vision LLM configs from YAML. - - The ``billing_tier`` field allows the frontend to show a Premium/Free - badge and (more importantly) tells the backend whether to debit the - user's premium credit pool when this config is used. ``"free"`` is - the default for backward compatibility — admins must explicitly opt - a global config into ``"premium"``. - """ - - id: int = Field(...) - name: str - description: str | None = None - provider: str - custom_provider: str | None = None - model_name: str - api_base: str | None = None - api_version: str | None = None - litellm_params: dict[str, Any] | None = None - is_global: bool = True - is_auto_mode: bool = False - billing_tier: str = Field( - default="free", - description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", - ) - is_premium: bool = Field( - default=False, - description=( - "Convenience boolean derived server-side from " - "``billing_tier == 'premium'``. The new-chat model selector " - "keys its Free/Premium badge off this field for parity with " - "chat (`GlobalLLMConfigRead.is_premium`)." - ), - ) - quota_reserve_tokens: int | None = Field( - default=None, - description=( - "Optional override for the per-call reservation in *tokens* — " - "converted to micro-USD via the model's input/output prices at " - "reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS." - ), - ) - input_cost_per_token: float | None = Field( - default=None, - description=( - "Optional input price in USD/token. Used by pricing_registration to " - "register custom Azure / OpenRouter aliases with LiteLLM at startup." - ), - ) - output_cost_per_token: float | None = Field( - default=None, - description="Optional output price in USD/token. Pair with input_cost_per_token.", - ) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index b4f1bafc9..dfd7c7be3 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -1,13 +1,13 @@ -"""Resolve and persist Auto (Fastest) model pins per chat thread. +"""Resolve and persist Auto model pins per chat thread. -Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we -resolve that virtual mode to one concrete global LLM config exactly once and +Auto is represented by ``chat_model_id == 0``. For chat threads we +resolve that virtual mode to one concrete global model exactly once and persist the chosen config id on ``new_chat_threads.pinned_llm_config_id`` so subsequent turns are stable. Single-writer invariant: this module is the only writer of ``NewChatThread.pinned_llm_config_id`` (aside from the bulk clear in -``search_spaces_routes`` when a search space's ``agent_llm_id`` changes). +``model_connections_routes`` when a search space's ``chat_model_id`` changes). Therefore a non-NULL value unambiguously means "this thread has an Auto-resolved pin"; no separate source/policy column is needed. """ @@ -33,8 +33,10 @@ from app.services.token_quota_service import TokenQuotaService logger = logging.getLogger(__name__) -AUTO_FASTEST_ID = 0 -AUTO_FASTEST_MODE = "auto_fastest" +AUTO_MODE_ID = 0 +# Stable internal hash namespace for deterministic per-thread selection. +# Do not rename: changing this rebalances Auto's model choice for new pins. +AUTO_PIN_HASH_NAMESPACE = "auto_fastest" _RUNTIME_COOLDOWN_SECONDS = 600 _HEALTHY_TTL_SECONDS = 45 @@ -383,7 +385,7 @@ def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]: pool = tier_a if tier_a else eligible pool = sorted(pool, key=lambda c: -int(c.get("quality_score") or 0)) top_k = pool[:_QUALITY_TOP_K] - digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest() + digest = hashlib.sha256(f"{AUTO_PIN_HASH_NAMESPACE}:{thread_id}".encode()).digest() idx = int.from_bytes(digest[:8], "big") % len(top_k) return top_k[idx], len(top_k) @@ -425,7 +427,7 @@ async def resolve_or_get_pinned_llm_config_id( exclude_config_ids: set[int] | None = None, requires_image_input: bool = False, ) -> AutoPinResolution: - """Resolve Auto (Fastest) to one concrete config id and persist the pin. + """Resolve Auto to one concrete config id and persist the pin. For non-auto selections, this function clears any existing pin and returns the selected id as-is. @@ -457,7 +459,7 @@ async def resolve_or_get_pinned_llm_config_id( ) # Explicit model selected: clear any stale pin. - if selected_llm_config_id != AUTO_FASTEST_ID: + if selected_llm_config_id != AUTO_MODE_ID: if thread.pinned_llm_config_id is not None: thread.pinned_llm_config_id = None await session.commit() diff --git a/surfsense_backend/app/services/billable_calls.py b/surfsense_backend/app/services/billable_calls.py index f21f52e14..15a3c3e55 100644 --- a/surfsense_backend/app/services/billable_calls.py +++ b/surfsense_backend/app/services/billable_calls.py @@ -450,10 +450,10 @@ async def _resolve_agent_billing_for_search_space( Used by Celery tasks (podcast generation, video presentation) to bill the search-space owner's premium credit pool when the chat model is premium. - Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``: + Resolution rules mirror the chat model role resolver: - - Search space not found / no ``agent_llm_id``: raise ``ValueError``. - - **Auto mode** (``id == AUTO_FASTEST_ID == 0``): + - Search space not found / no ``chat_model_id``: raise ``ValueError``. + - **Auto mode** (``id == AUTO_MODE_ID == 0``): * ``thread_id`` is set: delegate to ``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and recurse into the resolved id. Reuses chat's existing pin if present @@ -469,9 +469,8 @@ async def _resolve_agent_billing_for_search_space( (defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault), ``base_model = litellm_params.get("base_model") or model_name`` — NOT provider-prefixed, matching chat's cost-map lookup convention. - - **Positive id** (user BYOK ``NewLLMConfig``): always free (matches - ``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``); - ``base_model`` from ``litellm_params`` or ``model_name``. + - **Positive id** (user BYOK ``Model``): always free; ``base_model`` from + the model catalog override or the upstream ``model_id``. Note on imports: ``llm_service``, ``auto_model_pin_service``, and ``llm_router_service`` are imported lazily inside the function body to @@ -480,8 +479,9 @@ async def _resolve_agent_billing_for_search_space( ``billable_calls.py``'s module load path. """ from sqlalchemy import select + from sqlalchemy.orm import selectinload - from app.db import NewLLMConfig, SearchSpace + from app.db import Model, SearchSpace result = await session.execute( select(SearchSpace).where(SearchSpace.id == search_space_id) @@ -490,20 +490,20 @@ async def _resolve_agent_billing_for_search_space( if search_space is None: raise ValueError(f"Search space {search_space_id} not found") - agent_llm_id = search_space.agent_llm_id - if agent_llm_id is None: + chat_model_id = search_space.chat_model_id + if chat_model_id is None: raise ValueError( - f"Search space {search_space_id} has no agent_llm_id configured" + f"Search space {search_space_id} has no chat_model_id configured" ) owner_user_id: UUID = search_space.user_id from app.services.auto_model_pin_service import ( - AUTO_FASTEST_ID, + AUTO_MODE_ID, resolve_or_get_pinned_llm_config_id, ) - if agent_llm_id == AUTO_FASTEST_ID: + if chat_model_id == AUTO_MODE_ID: if thread_id is None: return owner_user_id, "free", "auto" try: @@ -512,7 +512,7 @@ async def _resolve_agent_billing_for_search_space( thread_id=thread_id, search_space_id=search_space_id, user_id=str(owner_user_id), - selected_llm_config_id=AUTO_FASTEST_ID, + selected_llm_config_id=AUTO_MODE_ID, ) except ValueError: logger.warning( @@ -523,28 +523,35 @@ async def _resolve_agent_billing_for_search_space( exc_info=True, ) return owner_user_id, "free", "auto" - agent_llm_id = resolution.resolved_llm_config_id + chat_model_id = resolution.resolved_llm_config_id - if agent_llm_id < 0: + if chat_model_id < 0: from app.services.llm_service import get_global_llm_config - cfg = get_global_llm_config(agent_llm_id) or {} + cfg = get_global_llm_config(chat_model_id) or {} billing_tier = str(cfg.get("billing_tier", "free")).lower() litellm_params = cfg.get("litellm_params") or {} base_model = litellm_params.get("base_model") or cfg.get("model_name") or "" return owner_user_id, billing_tier, base_model - nlc_result = await session.execute( - select(NewLLMConfig).where( - NewLLMConfig.id == agent_llm_id, - NewLLMConfig.search_space_id == search_space_id, - ) + model_result = await session.execute( + select(Model) + .options(selectinload(Model.connection)) + .where(Model.id == chat_model_id, Model.enabled.is_(True)) ) - nlc = nlc_result.scalars().first() + model = model_result.scalars().first() base_model = "" - if nlc is not None: - litellm_params = nlc.litellm_params or {} - base_model = litellm_params.get("base_model") or nlc.model_name or "" + if ( + model is not None + and model.connection is not None + and model.connection.enabled + and ( + model.connection.search_space_id in (None, search_space_id) + and model.connection.user_id in (None, owner_user_id) + ) + ): + catalog = model.catalog or {} + base_model = catalog.get("base_model") or model.model_id or "" return owner_user_id, "free", base_model diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index eadb4dbf8..277929e96 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -14,7 +14,11 @@ from app.services.auto_model_pin_service import ( auto_model_candidates, choose_auto_model_candidate, ) -from app.services.llm_router_service import AUTO_MODE_ID, ChatLiteLLMRouter, is_auto_mode +from app.services.llm_router_service import ( + AUTO_MODE_ID, + ChatLiteLLMRouter, + is_auto_mode, +) from app.services.model_capabilities import has_capability from app.services.model_resolver import native_connection_from_config, to_litellm from app.services.token_tracking_service import token_tracker @@ -96,26 +100,16 @@ class LLMRole: def get_global_llm_config(llm_config_id: int) -> dict | None: """ Get a global LLM configuration by ID. - Global configs have negative IDs. ID 0 is reserved for Auto mode. + Global configs have negative IDs. Auto mode (ID 0) is resolved through the + model-candidate pipeline, not this legacy config lookup. Args: - llm_config_id: The ID of the global config (should be negative or 0 for Auto) + llm_config_id: The ID of the global config (must be negative) Returns: dict: Global config dictionary or None if not found """ - # Auto mode (ID 0) is handled separately via the router - if llm_config_id == AUTO_MODE_ID: - return { - "id": AUTO_MODE_ID, - "name": "Auto (Fastest)", - "description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling", - "provider": "AUTO", - "model_name": "auto", - "is_auto_mode": True, - } - - if llm_config_id > 0: + if llm_config_id >= 0: return None for cfg in config.GLOBAL_LLM_CONFIGS: diff --git a/surfsense_backend/app/services/model_list_service.py b/surfsense_backend/app/services/model_list_service.py index 0761d7e4f..ffb430756 100644 --- a/surfsense_backend/app/services/model_list_service.py +++ b/surfsense_backend/app/services/model_list_service.py @@ -24,7 +24,7 @@ CACHE_TTL_SECONDS = 86400 # 24 hours _cache: list[dict] | None = None _cache_timestamp: float = 0 -# Maps OpenRouter provider slug → our LiteLLMProvider enum value. +# Maps OpenRouter provider slug to native LiteLLM provider prefixes. # Only providers where the model-name part (after the slash) can be # used directly with the native provider's litellm prefix are listed. # diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index fbb70eb5a..8f4c4cb5f 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -281,7 +281,7 @@ def _generate_configs( OpenRouter's own ``openrouter/free`` meta-router is filtered out upstream via ``_EXCLUDED_MODEL_IDS``; we don't expose a redundant auto-select layer - because our own Auto (Fastest) pin + 24 h refresh + repair logic already + because our own Auto pin + 24 h refresh + repair logic already cover the catalogue-churn case. """ id_offset: int = settings.get("id_offset", -10000) @@ -346,7 +346,7 @@ def _generate_configs( # ``"No endpoints found that support image input"``. "supports_image_input": bool(normalized.get("supports_image_input")), _OPENROUTER_DYNAMIC_MARKER: True, - # Auto (Fastest) ranking metadata. ``quality_score`` is initialised + # Auto ranking metadata. ``quality_score`` is initialised # to the static score and gets re-blended with health on the next # ``_enrich_health`` pass (synchronous on refresh, deferred on cold # start so startup latency is unchanged). @@ -361,11 +361,7 @@ def _generate_configs( return configs -# ID-offset bands used to keep dynamic OpenRouter configs in their own -# namespace per surface. Image / vision get separate bands so a single -# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to. _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000 -_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000 def _generate_image_gen_configs( @@ -431,89 +427,6 @@ def _generate_image_gen_configs( return configs -def _generate_vision_llm_configs( - raw_models: list[dict], settings: dict[str, Any] -) -> list[dict]: - """Convert OpenRouter vision-capable LLMs into global vision-LLM config - dicts (matches the YAML shape consumed by ``vision_llm_routes``). - - Filter: - - architecture.input_modalities contains "image" - - architecture.output_modalities contains "text" - - compatible provider (excluded slugs blocked) - - allowed model id (excluded list blocked) - - Vision-LLM is invoked from the indexer (image extraction during - document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so - the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context`` - filters do not apply: a small-context vision model that doesn't - advertise tool-calling is still perfectly viable for "describe this - image" prompts. - """ - id_offset: int = int( - settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT - ) - api_key: str = settings.get("api_key", "") - rpm: int = settings.get("rpm", 200) - tpm: int = settings.get("tpm", 1_000_000) - free_rpm: int = settings.get("free_rpm", 20) - free_tpm: int = settings.get("free_tpm", 100_000) - quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000) - litellm_params: dict = settings.get("litellm_params") or {} - - vision_models = [ - m - for m in raw_models - if supports_image_input(m) - and _shared_is_compatible_provider(m) - and _shared_is_allowed_model(m) - and "/" in m.get("id", "") - ] - - configs: list[dict] = [] - taken: set[int] = set() - for model in vision_models: - model_id: str = model["id"] - name: str = model.get("name", model_id) - tier = _openrouter_tier(model) - pricing = model.get("pricing") or {} - - # Capture per-token prices so ``pricing_registration`` can - # register them with LiteLLM at startup (and so the cost - # estimator in ``estimate_call_reserve_micros`` can resolve - # them at reserve time). - try: - input_cost = float(pricing.get("prompt", 0) or 0) - except (TypeError, ValueError): - input_cost = 0.0 - try: - output_cost = float(pricing.get("completion", 0) or 0) - except (TypeError, ValueError): - output_cost = 0.0 - - cfg: dict[str, Any] = { - "id": _stable_config_id(model_id, id_offset, taken), - "name": name, - "description": f"{name} via OpenRouter (vision)", - "provider": "openrouter", - "model_name": model_id, - "api_key": api_key, - "api_base": "https://openrouter.ai/api/v1", - "api_version": None, - "rpm": free_rpm if tier == "free" else rpm, - "tpm": free_tpm if tier == "free" else tpm, - "litellm_params": dict(litellm_params), - "billing_tier": tier, - "quota_reserve_tokens": quota_reserve_tokens, - "input_cost_per_token": input_cost or None, - "output_cost_per_token": output_cost or None, - _OPENROUTER_DYNAMIC_MARKER: True, - } - configs.append(cfg) - - return configs - - class OpenRouterIntegrationService: """Singleton that manages the dynamic OpenRouter model catalogue.""" @@ -724,7 +637,7 @@ class OpenRouterIntegrationService: return counts # ------------------------------------------------------------------ - # Auto (Fastest) health enrichment + # Auto health enrichment # ------------------------------------------------------------------ async def _enrich_health_safely( diff --git a/surfsense_backend/app/services/pricing_registration.py b/surfsense_backend/app/services/pricing_registration.py index 9e4e3b552..7343df737 100644 --- a/surfsense_backend/app/services/pricing_registration.py +++ b/surfsense_backend/app/services/pricing_registration.py @@ -154,10 +154,8 @@ def _register_chat_shape_configs( input_cost = _safe_float(entry.get("prompt")) output_cost = _safe_float(entry.get("completion")) else: - # Vision configs from ``_generate_vision_llm_configs`` - # carry their pricing inline because the OpenRouter - # raw-pricing cache is keyed by chat-catalogue model_id; - # vision flows pick up the inline values here. + # Some dynamically materialized configs can carry pricing + # inline when the raw OpenRouter cache has no matching entry. input_cost = _safe_float(cfg.get("input_cost_per_token")) output_cost = _safe_float(cfg.get("output_cost_per_token")) if input_cost == 0.0 and output_cost == 0.0: diff --git a/surfsense_backend/app/services/quality_score.py b/surfsense_backend/app/services/quality_score.py index 9cc9c21ac..737dd7c2f 100644 --- a/surfsense_backend/app/services/quality_score.py +++ b/surfsense_backend/app/services/quality_score.py @@ -1,4 +1,4 @@ -"""Pure-function quality scoring for Auto (Fastest) model selection. +"""Pure-function quality scoring for Auto model selection. This module is import-free of any service / request-path dependencies. All numbers are computed once during the OpenRouter refresh tick (or YAML load) diff --git a/surfsense_backend/app/services/vision_llm_router_service.py b/surfsense_backend/app/services/vision_llm_router_service.py deleted file mode 100644 index 0ff716324..000000000 --- a/surfsense_backend/app/services/vision_llm_router_service.py +++ /dev/null @@ -1,160 +0,0 @@ -import logging -from typing import Any - -from litellm import Router - -from app.services.model_resolver import native_connection_from_config, to_litellm - -logger = logging.getLogger(__name__) - -VISION_AUTO_MODE_ID = 0 - -class VisionLLMRouterService: - _instance = None - _router: Router | None = None - _model_list: list[dict] = [] - _router_settings: dict = {} - _initialized: bool = False - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - @classmethod - def get_instance(cls) -> "VisionLLMRouterService": - if cls._instance is None: - cls._instance = cls() - return cls._instance - - @classmethod - def initialize( - cls, - global_configs: list[dict], - router_settings: dict | None = None, - ) -> None: - instance = cls.get_instance() - - if instance._initialized: - logger.debug("Vision LLM Router already initialized, skipping") - return - - model_list = [] - for config in global_configs: - deployment = cls._config_to_deployment(config) - if deployment: - model_list.append(deployment) - - if not model_list: - logger.warning( - "No valid vision LLM configs found for router initialization" - ) - return - - instance._model_list = model_list - instance._router_settings = router_settings or {} - - default_settings = { - "routing_strategy": "usage-based-routing", - "num_retries": 3, - "allowed_fails": 3, - "cooldown_time": 60, - "retry_after": 5, - } - - final_settings = {**default_settings, **instance._router_settings} - - try: - instance._router = Router( - model_list=model_list, - routing_strategy=final_settings.get( - "routing_strategy", "usage-based-routing" - ), - num_retries=final_settings.get("num_retries", 3), - allowed_fails=final_settings.get("allowed_fails", 3), - cooldown_time=final_settings.get("cooldown_time", 60), - set_verbose=False, - ) - instance._initialized = True - logger.info( - "Vision LLM Router initialized with %d deployments, strategy: %s", - len(model_list), - final_settings.get("routing_strategy"), - ) - except Exception as e: - logger.error(f"Failed to initialize Vision LLM Router: {e}") - instance._router = None - - @classmethod - def _config_to_deployment(cls, config: dict) -> dict | None: - try: - if not config.get("model_name") or not config.get("api_key"): - return None - - model_string, resolved_kwargs = to_litellm( - native_connection_from_config(config), - config["model_name"], - ) - litellm_params: dict[str, Any] = {"model": model_string, **resolved_kwargs} - - deployment: dict[str, Any] = { - "model_name": "auto", - "litellm_params": litellm_params, - } - - if config.get("rpm"): - deployment["rpm"] = config["rpm"] - if config.get("tpm"): - deployment["tpm"] = config["tpm"] - - return deployment - - except Exception as e: - logger.warning(f"Failed to convert vision config to deployment: {e}") - return None - - @classmethod - def get_router(cls) -> Router | None: - instance = cls.get_instance() - return instance._router - - @classmethod - def is_initialized(cls) -> bool: - instance = cls.get_instance() - return instance._initialized and instance._router is not None - - @classmethod - def get_model_count(cls) -> int: - instance = cls.get_instance() - return len(instance._model_list) - - -def is_vision_auto_mode(config_id: int | None) -> bool: - return config_id == VISION_AUTO_MODE_ID - - -def build_vision_model_string( - litellm_provider: str, model_name: str, custom_provider: str | None -) -> str: - if custom_provider: - return f"{custom_provider}/{model_name}" - return f"{litellm_provider}/{model_name}" - - -def get_global_vision_llm_config(config_id: int) -> dict | None: - from app.config import config - - if config_id == VISION_AUTO_MODE_ID: - return { - "id": VISION_AUTO_MODE_ID, - "name": "Auto (Fastest)", - "provider": "AUTO", - "model_name": "auto", - "is_auto_mode": True, - } - if config_id > 0: - return None - for cfg in config.GLOBAL_VISION_LLM_CONFIGS: - if cfg.get("id") == config_id: - return cfg - return None diff --git a/surfsense_backend/app/services/vision_model_list_service.py b/surfsense_backend/app/services/vision_model_list_service.py deleted file mode 100644 index 6eae8c455..000000000 --- a/surfsense_backend/app/services/vision_model_list_service.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -Service for fetching and caching the vision-capable model list. - -Reuses the same OpenRouter public API and local fallback as the LLM model -list service, but filters for models that accept image input. -""" - -import json -import logging -import time -from pathlib import Path - -import httpx - -logger = logging.getLogger(__name__) - -OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models" -FALLBACK_FILE = ( - Path(__file__).parent.parent / "config" / "vision_model_list_fallback.json" -) -CACHE_TTL_SECONDS = 86400 # 24 hours - -_cache: list[dict] | None = None -_cache_timestamp: float = 0 - -OPENROUTER_SLUG_TO_VISION_PROVIDER: dict[str, str] = { - "openai": "OPENAI", - "anthropic": "ANTHROPIC", - "google": "GOOGLE", - "mistralai": "MISTRAL", - "x-ai": "XAI", -} - - -def _format_context_length(length: int | None) -> str | None: - if not length: - return None - if length >= 1_000_000: - return f"{length / 1_000_000:g}M" - if length >= 1_000: - return f"{length / 1_000:g}K" - return str(length) - - -async def _fetch_from_openrouter() -> list[dict] | None: - try: - async with httpx.AsyncClient(timeout=15) as client: - response = await client.get(OPENROUTER_API_URL) - response.raise_for_status() - data = response.json() - return data.get("data", []) - except Exception as e: - logger.warning("Failed to fetch from OpenRouter API for vision models: %s", e) - return None - - -def _load_fallback() -> list[dict]: - try: - with open(FALLBACK_FILE, encoding="utf-8") as f: - return json.load(f) - except Exception as e: - logger.error("Failed to load vision model fallback list: %s", e) - return [] - - -def _is_vision_model(model: dict) -> bool: - """Return True if the model accepts image input and outputs text.""" - arch = model.get("architecture", {}) - input_mods = arch.get("input_modalities", []) - output_mods = arch.get("output_modalities", []) - return "image" in input_mods and "text" in output_mods - - -def _process_vision_models(raw_models: list[dict]) -> list[dict]: - processed: list[dict] = [] - - for model in raw_models: - model_id: str = model.get("id", "") - name: str = model.get("name", "") - context_length = model.get("context_length") - - if "/" not in model_id: - continue - - if not _is_vision_model(model): - continue - - provider_slug, model_name = model_id.split("/", 1) - context_window = _format_context_length(context_length) - - processed.append( - { - "value": model_id, - "label": name, - "provider": "OPENROUTER", - "context_window": context_window, - } - ) - - direct_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug) - if direct_provider: - if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"): - continue - - processed.append( - { - "value": model_name, - "label": name, - "provider": direct_provider, - "context_window": context_window, - } - ) - - return processed - - -async def get_vision_model_list() -> list[dict]: - global _cache, _cache_timestamp - - if _cache is not None and (time.time() - _cache_timestamp) < CACHE_TTL_SECONDS: - return _cache - - raw_models = await _fetch_from_openrouter() - - if raw_models is None: - logger.info("Using fallback vision model list") - return _load_fallback() - - processed = _process_vision_models(raw_models) - - _cache = processed - _cache_timestamp = time.time() - - return processed diff --git a/surfsense_backend/scripts/verify_chat_image_capability.py b/surfsense_backend/scripts/verify_chat_image_capability.py index 6e711f99a..e6a535711 100644 --- a/surfsense_backend/scripts/verify_chat_image_capability.py +++ b/surfsense_backend/scripts/verify_chat_image_capability.py @@ -330,31 +330,6 @@ async def probe_chat_configs(report: Report, *, live: bool) -> None: report.add(result) -async def probe_vision_configs(report: Report, *, live: bool) -> None: - print("\n[vision configs from global_vision_llm_configs (YAML-static)]") - for cfg in config.GLOBAL_VISION_LLM_CONFIGS: - if _is_or_dynamic(cfg): - continue - result = ProbeResult( - label=str(cfg.get("name") or cfg.get("model_name")), - surface="vision", - config_id=cfg.get("id"), - ) - # For vision configs, capability is implied — they're in the - # dedicated vision pool. Run the same resolver to flag any - # surprise disagreement. - cap_ok, cap_note = _probe_chat_capability(cfg) - result.capability_ok = cap_ok - result.capability_note = cap_note - if live: - t0 = time.perf_counter() - ok, note = await _live_chat_image_call(cfg) - result.live_ok = ok - result.live_note = note - result.duration_s = time.perf_counter() - t0 - report.add(result) - - async def probe_image_gen_configs(report: Report, *, live: bool) -> None: print( "\n[image generation configs from global_image_generation_configs (YAML-static)]" @@ -380,7 +355,7 @@ async def probe_image_gen_configs(report: Report, *, live: bool) -> None: async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: - """Sample one chat (vision-capable), one vision, one image-gen model + """Sample chat/vision-capable and image-gen models from the live OpenRouter catalogue. Doesn't iterate the full pool (would be hundreds of probes); just validates the integration end- to-end on a representative model from each surface.""" @@ -405,9 +380,6 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: for c in config.GLOBAL_LLM_CONFIGS if c.get("provider") == "OPENROUTER" and c.get("supports_image_input") ] - or_vision = [ - c for c in config.GLOBAL_VISION_LLM_CONFIGS if c.get("provider") == "OPENROUTER" - ] or_image_gen = [ c for c in config.GLOBAL_IMAGE_GEN_CONFIGS if c.get("provider") == "OPENROUTER" ] @@ -427,11 +399,6 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: ("or-chat", _pick_first(or_chat, "anthropic/claude")), ("or-chat", _pick_first(or_chat, "google/gemini-2.5-flash")), ] - vision_picks = [ - ("or-vision", _pick_first(or_vision, "openai/gpt-4o")), - ("or-vision", _pick_first(or_vision, "anthropic/claude")), - ("or-vision", _pick_first(or_vision, "google/gemini-2.5-flash")), - ] image_picks = [ ("or-image", _pick_first(or_image_gen, "google/gemini-2.5-flash-image")), # OpenRouter publishes OpenAI image models as ``openai/gpt-5-image*`` @@ -441,11 +408,11 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: ] print( - f" catalog: chat={len(or_chat)} vision={len(or_vision)} image_gen={len(or_image_gen)} " + f" catalog: chat_vision={len(or_chat)} image_gen={len(or_image_gen)} " f"(service initialized={service.is_initialized() if hasattr(service, 'is_initialized') else 'n/a'})" ) - for surface, picked in chat_picks + vision_picks + image_picks: + for surface, picked in chat_picks + image_picks: if not picked: report.add( ProbeResult( @@ -486,7 +453,6 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: async def main(args: argparse.Namespace) -> int: print("Loaded global configs:") print(f" chat: {len(config.GLOBAL_LLM_CONFIGS)} entries") - print(f" vision: {len(config.GLOBAL_VISION_LLM_CONFIGS)} entries") print(f" image-gen: {len(config.GLOBAL_IMAGE_GEN_CONFIGS)} entries") print(f" OR settings present: {bool(config.OPENROUTER_INTEGRATION_SETTINGS)}") @@ -507,8 +473,6 @@ async def main(args: argparse.Namespace) -> int: report = Report() if not args.skip_chat: await probe_chat_configs(report, live=args.live) - if not args.skip_vision: - await probe_vision_configs(report, live=args.live) if not args.skip_image_gen: await probe_image_gen_configs(report, live=args.live) if not args.skip_openrouter: @@ -528,7 +492,6 @@ def _parse_args() -> argparse.Namespace: ) parser.set_defaults(live=True) parser.add_argument("--skip-chat", action="store_true") - parser.add_argument("--skip-vision", action="store_true") parser.add_argument("--skip-image-gen", action="store_true") parser.add_argument("--skip-openrouter", action="store_true") return parser.parse_args() diff --git a/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py b/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py index 79da12933..f5709e517 100644 --- a/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py +++ b/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py @@ -1,6 +1,6 @@ """Lock the runtime model-policy backstop in ``build_dependencies``. -Automations resolve their LLM from the *captured* ``agent_llm_id`` snapshot (so +Automations resolve their LLM from the *captured* ``chat_model_id`` snapshot (so runs are insulated from later chat/search-space model changes), and the model policy is re-checked at run time so a captured model that is no longer billable fails the run clearly. When no snapshot is present, resolution falls back to the @@ -45,10 +45,10 @@ def patched_side_effects(monkeypatch: pytest.MonkeyPatch): return None -async def test_build_dependencies_resolves_captured_agent_llm_id( +async def test_build_dependencies_resolves_captured_chat_model_id( monkeypatch: pytest.MonkeyPatch, patched_side_effects ) -> None: - """The bundle loads with the *captured* ``agent_llm_id``, not the live search space.""" + """The bundle loads with the *captured* ``chat_model_id``, not the live search space.""" captured: dict[str, Any] = {} async def _fake_load(_session, *, config_id, search_space_id): @@ -67,13 +67,13 @@ async def test_build_dependencies_resolves_captured_agent_llm_id( lambda _ss: pytest.fail("search-space policy should not run on captured path"), ) - search_space = SimpleNamespace(agent_llm_id=-99) + search_space = SimpleNamespace(chat_model_id=-99) result = await build_dependencies( session=_FakeSession(search_space), search_space_id=42, - agent_llm_id=-7, - image_generation_config_id=5, - vision_llm_config_id=-1, + chat_model_id=-7, + image_gen_model_id=5, + vision_model_id=-1, ) assert captured == {"config_id": -7, "search_space_id": 42} @@ -98,17 +98,17 @@ async def test_build_dependencies_validates_captured_ids( monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load) await build_dependencies( - session=_FakeSession(SimpleNamespace(agent_llm_id=0)), + session=_FakeSession(SimpleNamespace(chat_model_id=0)), search_space_id=42, - agent_llm_id=-7, - image_generation_config_id=5, - vision_llm_config_id=-1, + chat_model_id=-7, + image_gen_model_id=5, + vision_model_id=-1, ) assert seen == { - "agent_llm_id": -7, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -7, + "image_gen_model_id": 5, + "vision_model_id": -1, } @@ -119,7 +119,7 @@ async def test_build_dependencies_raises_on_captured_policy_violation( def _raise(**_kw): raise AutomationModelPolicyError( - [{"kind": "image", "config_id": -2, "reason": "free model"}] + [{"kind": "image", "model_id": -2, "reason": "free model"}] ) monkeypatch.setattr(deps_mod, "assert_models_billable", _raise) @@ -131,11 +131,11 @@ async def test_build_dependencies_raises_on_captured_policy_violation( with pytest.raises(DependencyError): await build_dependencies( - session=_FakeSession(SimpleNamespace(agent_llm_id=-7)), + session=_FakeSession(SimpleNamespace(chat_model_id=-7)), search_space_id=42, - agent_llm_id=-7, - image_generation_config_id=-2, - vision_llm_config_id=-1, + chat_model_id=-7, + image_gen_model_id=-2, + vision_model_id=-1, ) @@ -157,7 +157,7 @@ async def test_build_dependencies_falls_back_to_search_space( lambda **_kw: pytest.fail("captured policy should not run on fallback path"), ) - search_space = SimpleNamespace(agent_llm_id=-7) + search_space = SimpleNamespace(chat_model_id=-7) result = await build_dependencies( session=_FakeSession(search_space), search_space_id=42 ) diff --git a/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py b/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py index d7e3c4a0c..c89624fbf 100644 --- a/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py +++ b/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py @@ -28,9 +28,9 @@ def _run() -> SimpleNamespace: def test_build_action_ctx_propagates_captured_models() -> None: """``definition.models`` flows onto the ActionContext model fields.""" models = AutomationModels( - agent_llm_id=-1, - image_generation_config_id=5, - vision_llm_config_id=-1, + chat_model_id=-1, + image_gen_model_id=5, + vision_model_id=-1, ) ctx = _build_action_ctx( cast(AsyncSession, None), @@ -40,9 +40,9 @@ def test_build_action_ctx_propagates_captured_models() -> None: ) assert ctx.search_space_id == 42 - assert ctx.agent_llm_id == -1 - assert ctx.image_generation_config_id == 5 - assert ctx.vision_llm_config_id == -1 + assert ctx.chat_model_id == -1 + assert ctx.image_gen_model_id == 5 + assert ctx.vision_model_id == -1 def test_build_action_ctx_none_models_leaves_fields_none() -> None: @@ -54,6 +54,6 @@ def test_build_action_ctx_none_models_leaves_fields_none() -> None: None, ) - assert ctx.agent_llm_id is None - assert ctx.image_generation_config_id is None - assert ctx.vision_llm_config_id is None + assert ctx.chat_model_id is None + assert ctx.image_gen_model_id is None + assert ctx.vision_model_id is None diff --git a/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py b/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py index 25e193ffa..dc7221b11 100644 --- a/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py +++ b/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py @@ -40,24 +40,24 @@ def test_automation_definition_models_round_trip() -> None: name="Daily digest", plan=[PlanStep(step_id="s1", action="agent_task")], models=AutomationModels( - agent_llm_id=-1, - image_generation_config_id=5, - vision_llm_config_id=-1, + chat_model_id=-1, + image_gen_model_id=5, + vision_model_id=-1, ), ) dumped = definition.model_dump(mode="json", by_alias=True) assert dumped["models"] == { - "agent_llm_id": -1, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -1, + "image_gen_model_id": 5, + "vision_model_id": -1, } restored = AutomationDefinition.model_validate(dumped) assert restored.models is not None - assert restored.models.agent_llm_id == -1 - assert restored.models.image_generation_config_id == 5 - assert restored.models.vision_llm_config_id == -1 + assert restored.models.chat_model_id == -1 + assert restored.models.image_gen_model_id == 5 + assert restored.models.vision_model_id == -1 def test_automation_definition_rejects_unknown_top_level_field() -> None: diff --git a/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py b/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py index 0bbff39dc..c97dec6a2 100644 --- a/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py +++ b/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py @@ -64,12 +64,12 @@ async def test_assert_models_billable_raises_422_on_violation( def _raise(_ss): raise AutomationModelPolicyError( - [{"kind": "llm", "config_id": 0, "reason": "Auto mode"}] + [{"kind": "llm", "model_id": 0, "reason": "Auto mode"}] ) monkeypatch.setattr(automation_mod, "assert_automation_models_billable", _raise) - service = _service(SimpleNamespace(agent_llm_id=0)) + service = _service(SimpleNamespace(chat_model_id=0)) with pytest.raises(HTTPException) as exc_info: await service._assert_models_billable(1) @@ -99,7 +99,7 @@ async def test_assert_models_billable_returns_search_space_when_ok( automation_mod, "assert_automation_models_billable", lambda _ss: None ) - search_space = SimpleNamespace(agent_llm_id=-1) + search_space = SimpleNamespace(chat_model_id=-1) service = _service(search_space) assert await service._assert_models_billable(1) is search_space @@ -123,9 +123,9 @@ async def test_create_injects_captured_models_from_search_space( monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) search_space = SimpleNamespace( - agent_llm_id=-1, - image_generation_config_id=5, - vision_llm_config_id=-1, + chat_model_id=-1, + image_gen_model_id=5, + vision_model_id=-1, ) service = _service(search_space) payload = AutomationCreate( @@ -137,9 +137,9 @@ async def test_create_injects_captured_models_from_search_space( automation = await service.create(payload) assert automation.definition["models"] == { - "agent_llm_id": -1, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -1, + "image_gen_model_id": 5, + "vision_model_id": -1, } @@ -162,9 +162,9 @@ async def test_create_treats_unset_prefs_as_auto_zero( monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) search_space = SimpleNamespace( - agent_llm_id=None, - image_generation_config_id=None, - vision_llm_config_id=None, + chat_model_id=None, + image_gen_model_id=None, + vision_model_id=None, ) service = _service(search_space) payload = AutomationCreate(search_space_id=1, name="A", definition=_definition()) @@ -172,9 +172,9 @@ async def test_create_treats_unset_prefs_as_auto_zero( automation = await service.create(payload) assert automation.definition["models"] == { - "agent_llm_id": 0, - "image_generation_config_id": 0, - "vision_llm_config_id": 0, + "chat_model_id": 0, + "image_gen_model_id": 0, + "vision_model_id": 0, } @@ -195,11 +195,11 @@ async def test_create_honors_selected_models_when_provided( ) validated: dict[str, Any] = {} - def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): + def _assert_ok(*, chat_model_id, image_gen_model_id, vision_model_id): validated["ids"] = ( - agent_llm_id, - image_generation_config_id, - vision_llm_config_id, + chat_model_id, + image_gen_model_id, + vision_model_id, ) monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok) @@ -213,15 +213,15 @@ async def test_create_honors_selected_models_when_provided( monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) - service = _service(SimpleNamespace(agent_llm_id=-99)) + service = _service(SimpleNamespace(chat_model_id=-99)) payload = AutomationCreate( search_space_id=1, name="A", definition=_definition( models=AutomationModels( - agent_llm_id=-1, - image_generation_config_id=7, - vision_llm_config_id=-2, + chat_model_id=-1, + image_gen_model_id=7, + vision_model_id=-2, ) ), ) @@ -230,9 +230,9 @@ async def test_create_honors_selected_models_when_provided( assert validated["ids"] == (-1, 7, -2) assert automation.definition["models"] == { - "agent_llm_id": -1, - "image_generation_config_id": 7, - "vision_llm_config_id": -2, + "chat_model_id": -1, + "image_gen_model_id": 7, + "vision_model_id": -2, } @@ -241,9 +241,9 @@ async def test_create_rejects_unbillable_selected_models( ) -> None: """A non-billable explicit selection maps the policy error to HTTP 422.""" - def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): + def _raise(*, chat_model_id, image_gen_model_id, vision_model_id): raise AutomationModelPolicyError( - [{"kind": "llm", "config_id": -3, "reason": "free model"}] + [{"kind": "llm", "model_id": -3, "reason": "free model"}] ) monkeypatch.setattr(automation_mod, "assert_models_billable", _raise) @@ -253,15 +253,15 @@ async def test_create_rejects_unbillable_selected_models( monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) - service = _service(SimpleNamespace(agent_llm_id=-3)) + service = _service(SimpleNamespace(chat_model_id=-3)) payload = AutomationCreate( search_space_id=1, name="A", definition=_definition( models=AutomationModels( - agent_llm_id=-3, - image_generation_config_id=7, - vision_llm_config_id=-2, + chat_model_id=-3, + image_gen_model_id=7, + vision_model_id=-2, ) ), ) @@ -277,9 +277,9 @@ async def test_update_preserves_captured_models( ) -> None: """A definition edit carries over the previously captured ``models``.""" captured = { - "agent_llm_id": -1, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -1, + "image_gen_model_id": 5, + "vision_model_id": -1, } existing = SimpleNamespace( search_space_id=1, @@ -318,20 +318,20 @@ async def test_update_honors_changed_models_when_valid( "name": "A", "plan": [], "models": { - "agent_llm_id": -1, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -1, + "image_gen_model_id": 5, + "vision_model_id": -1, }, }, version=3, ) validated: dict[str, Any] = {} - def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): + def _assert_ok(*, chat_model_id, image_gen_model_id, vision_model_id): validated["ids"] = ( - agent_llm_id, - image_generation_config_id, - vision_llm_config_id, + chat_model_id, + image_gen_model_id, + vision_model_id, ) monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok) @@ -351,9 +351,9 @@ async def test_update_honors_changed_models_when_valid( patch = AutomationUpdate( definition=_definition( models=AutomationModels( - agent_llm_id=-2, - image_generation_config_id=9, - vision_llm_config_id=-2, + chat_model_id=-2, + image_gen_model_id=9, + vision_model_id=-2, ) ) ) @@ -362,9 +362,9 @@ async def test_update_honors_changed_models_when_valid( assert validated["ids"] == (-2, 9, -2) assert result.definition["models"] == { - "agent_llm_id": -2, - "image_generation_config_id": 9, - "vision_llm_config_id": -2, + "chat_model_id": -2, + "image_gen_model_id": 9, + "vision_model_id": -2, } assert result.version == 4 @@ -379,17 +379,17 @@ async def test_update_rejects_changed_unbillable_models( "name": "A", "plan": [], "models": { - "agent_llm_id": -1, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -1, + "image_gen_model_id": 5, + "vision_model_id": -1, }, }, version=3, ) - def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): + def _raise(*, chat_model_id, image_gen_model_id, vision_model_id): raise AutomationModelPolicyError( - [{"kind": "llm", "config_id": -7, "reason": "free model"}] + [{"kind": "llm", "model_id": -7, "reason": "free model"}] ) monkeypatch.setattr(automation_mod, "assert_models_billable", _raise) @@ -409,9 +409,9 @@ async def test_update_rejects_changed_unbillable_models( patch = AutomationUpdate( definition=_definition( models=AutomationModels( - agent_llm_id=-7, - image_generation_config_id=5, - vision_llm_config_id=-1, + chat_model_id=-7, + image_gen_model_id=5, + vision_model_id=-1, ) ) ) @@ -431,9 +431,9 @@ async def test_update_keeps_unchanged_models_without_revalidation( premium without an unrelated edit tripping the policy check. """ captured = { - "agent_llm_id": -1, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -1, + "image_gen_model_id": 5, + "vision_model_id": -1, } existing = SimpleNamespace( search_space_id=1, @@ -485,7 +485,7 @@ async def test_model_eligibility_authorizes_and_returns_payload( lambda _ss: {"allowed": False, "violations": [{"kind": "image"}]}, ) - service = _service(SimpleNamespace(agent_llm_id=-2)) + service = _service(SimpleNamespace(chat_model_id=-2)) result = await service.model_eligibility(search_space_id=5) assert result == {"allowed": False, "violations": [{"kind": "image"}]} diff --git a/surfsense_backend/tests/unit/automations/services/test_model_policy.py b/surfsense_backend/tests/unit/automations/services/test_model_policy.py index 8e0806151..574f6d9fd 100644 --- a/surfsense_backend/tests/unit/automations/services/test_model_policy.py +++ b/surfsense_backend/tests/unit/automations/services/test_model_policy.py @@ -27,9 +27,9 @@ pytestmark = pytest.mark.unit def _search_space(*, llm: int | None, image: int | None, vision: int | None): """Minimal stand-in for the ``SearchSpace`` ORM row the policy reads.""" return SimpleNamespace( - agent_llm_id=llm, - image_generation_config_id=image, - vision_llm_config_id=vision, + chat_model_id=llm, + image_gen_model_id=image, + vision_model_id=vision, ) @@ -39,29 +39,11 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch): Negative ids: -1 is premium, -2 is free, for each of llm/image/vision. """ - llm_configs = { - -1: {"id": -1, "billing_tier": "premium"}, - -2: {"id": -2, "billing_tier": "free"}, - } - monkeypatch.setattr( - "app.agents.chat.runtime.llm_config.load_global_llm_config_by_id", - lambda cid: llm_configs.get(cid), - ) - from app.config import config as app_config monkeypatch.setattr( app_config, - "GLOBAL_IMAGE_GEN_CONFIGS", - [ - {"id": -1, "billing_tier": "premium"}, - {"id": -2, "billing_tier": "free"}, - ], - raising=False, - ) - monkeypatch.setattr( - app_config, - "GLOBAL_VISION_LLM_CONFIGS", + "GLOBAL_MODELS", [ {"id": -1, "billing_tier": "premium"}, {"id": -2, "billing_tier": "free"}, @@ -71,7 +53,7 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch): return None -@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None: """A positive config id is a user-owned BYOK model — always billable.""" allowed, reason = model_policy._classify(kind, 7) @@ -79,7 +61,7 @@ def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None: assert reason == "" -@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) @pytest.mark.parametrize("config_id", [0, None]) def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None: """Auto mode (id 0) and an unset slot (None) are blocked.""" @@ -88,7 +70,7 @@ def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None: assert "Auto mode" in reason -@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) def test_premium_global_is_allowed(kind: str, patched_globals) -> None: """A negative (global) id with premium billing tier is allowed.""" allowed, reason = model_policy._classify(kind, -1) @@ -96,7 +78,7 @@ def test_premium_global_is_allowed(kind: str, patched_globals) -> None: assert reason == "" -@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) def test_free_global_is_blocked(kind: str, patched_globals) -> None: """A negative (global) id with a free billing tier is blocked.""" allowed, reason = model_policy._classify(kind, -2) @@ -104,7 +86,7 @@ def test_free_global_is_blocked(kind: str, patched_globals) -> None: assert "free model" in reason -@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) def test_unknown_global_id_is_blocked(kind: str, patched_globals) -> None: """A negative id that resolves to no config is treated as not premium.""" allowed, _ = model_policy._classify(kind, -999) @@ -125,10 +107,10 @@ def test_eligibility_reports_each_violation(patched_globals) -> None: assert result["allowed"] is False kinds = {v["kind"] for v in result["violations"]} - assert kinds == {"llm", "image", "vision"} - # config_id is echoed back for the UI / settings deep-link. - by_kind = {v["kind"]: v["config_id"] for v in result["violations"]} - assert by_kind == {"llm": -2, "image": 0, "vision": -2} + assert kinds == {"chat", "image", "vision"} + # model_id is echoed back for the UI / settings deep-link. + by_kind = {v["kind"]: v["model_id"] for v in result["violations"]} + assert by_kind == {"chat": -2, "image": 0, "vision": -2} def test_assert_raises_with_violations(patched_globals) -> None: @@ -138,7 +120,7 @@ def test_assert_raises_with_violations(patched_globals) -> None: assert_automation_models_billable(search_space) assert len(exc_info.value.violations) == 1 - assert exc_info.value.violations[0]["kind"] == "llm" + assert exc_info.value.violations[0]["kind"] == "chat" def test_assert_passes_when_all_billable(patched_globals) -> None: @@ -153,7 +135,7 @@ def test_assert_passes_when_all_billable(patched_globals) -> None: def test_get_model_eligibility_all_billable(patched_globals) -> None: """Premium LLM + BYOK image + premium vision (explicit ids) → allowed.""" result = get_model_eligibility( - agent_llm_id=-1, image_generation_config_id=5, vision_llm_config_id=-1 + chat_model_id=-1, image_gen_model_id=5, vision_model_id=-1 ) assert result == {"allowed": True, "violations": []} @@ -161,28 +143,28 @@ def test_get_model_eligibility_all_billable(patched_globals) -> None: def test_get_model_eligibility_reports_each_violation(patched_globals) -> None: """Free LLM, Auto image, free vision (explicit ids) each produce a violation.""" result = get_model_eligibility( - agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2 + chat_model_id=-2, image_gen_model_id=0, vision_model_id=-2 ) assert result["allowed"] is False - by_kind = {v["kind"]: v["config_id"] for v in result["violations"]} - assert by_kind == {"llm": -2, "image": 0, "vision": -2} + by_kind = {v["kind"]: v["model_id"] for v in result["violations"]} + assert by_kind == {"chat": -2, "image": 0, "vision": -2} def test_assert_models_billable_raises(patched_globals) -> None: """``assert_models_billable`` raises when any explicit id is blocked.""" with pytest.raises(AutomationModelPolicyError) as exc_info: assert_models_billable( - agent_llm_id=0, image_generation_config_id=5, vision_llm_config_id=-1 + chat_model_id=0, image_gen_model_id=5, vision_model_id=-1 ) assert len(exc_info.value.violations) == 1 - assert exc_info.value.violations[0]["kind"] == "llm" + assert exc_info.value.violations[0]["kind"] == "chat" def test_assert_models_billable_passes(patched_globals) -> None: """No exception when every explicit id is premium or BYOK.""" assert ( assert_models_billable( - agent_llm_id=3, image_generation_config_id=-1, vision_llm_config_id=4 + chat_model_id=3, image_gen_model_id=-1, vision_model_id=4 ) is None ) @@ -192,5 +174,5 @@ def test_search_space_wrapper_delegates_to_core(patched_globals) -> None: """The search-space wrapper produces the same result as the ID core.""" search_space = _search_space(llm=-2, image=0, vision=-2) assert get_automation_model_eligibility(search_space) == get_model_eligibility( - agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2 + chat_model_id=-2, image_gen_model_id=0, vision_model_id=-2 ) diff --git a/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py b/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py deleted file mode 100644 index c9f18d77d..000000000 --- a/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Unit tests for ``supports_image_input`` derivation on BYOK chat config -endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``). - -There is no DB column for ``supports_image_input`` on -``NewLLMConfig`` — the value is resolved at the API boundary by -``derive_supports_image_input`` so the new-chat selector / streaming -task can read the same field shape regardless of source (BYOK vs YAML -vs OpenRouter dynamic). Default-allow on unknown so we don't lock the -user out of their own model choice. -""" - -from __future__ import annotations - -from datetime import UTC, datetime -from types import SimpleNamespace -from uuid import uuid4 - -import pytest - -from app.db import LiteLLMProvider -from app.routes import new_llm_config_routes - -pytestmark = pytest.mark.unit - - -def _byok_row( - *, - id_: int, - model_name: str, - base_model: str | None = None, - provider: LiteLLMProvider = LiteLLMProvider.OPENAI, - custom_provider: str | None = None, -) -> object: - """Mimic the SQLAlchemy row's attribute surface; ``model_validate`` - walks ``from_attributes=True`` so a ``SimpleNamespace`` is enough. - - ``provider`` is a real ``LiteLLMProvider`` enum value so Pydantic's - enum validator accepts it — same as the ORM row would carry.""" - return SimpleNamespace( - id=id_, - name=f"BYOK-{id_}", - description=None, - provider=provider, - custom_provider=custom_provider, - model_name=model_name, - api_key="sk-byok", - api_base=None, - litellm_params={"base_model": base_model} if base_model else None, - system_instructions="", - use_default_system_instructions=True, - citations_enabled=True, - created_at=datetime.now(tz=UTC), - search_space_id=42, - user_id=uuid4(), - ) - - -def test_serialize_byok_known_vision_model_resolves_true(): - """The catalog resolver consults LiteLLM's map for ``gpt-4o`` -> - True. The serialized row carries that value through to the - ``NewLLMConfigRead`` schema.""" - row = _byok_row(id_=1, model_name="gpt-4o") - serialized = new_llm_config_routes._serialize_byok_config(row) - - assert serialized.supports_image_input is True - assert serialized.id == 1 - assert serialized.model_name == "gpt-4o" - - -def test_serialize_byok_unknown_model_default_allows(): - """Unknown / unmapped: default-allow. The streaming-task safety net - is the actual block, and it requires LiteLLM to *explicitly* say - text-only — so a brand new BYOK model should not be pre-judged.""" - row = _byok_row( - id_=2, - model_name="brand-new-model-x9-unmapped", - provider=LiteLLMProvider.CUSTOM, - custom_provider="brand_new_proxy", - ) - serialized = new_llm_config_routes._serialize_byok_config(row) - - assert serialized.supports_image_input is True - - -def test_serialize_byok_uses_base_model_when_present(): - """Azure-style: ``model_name`` is the deployment id, ``base_model`` - inside ``litellm_params`` is the canonical sku LiteLLM knows. The - helper must consult ``base_model`` first or unrecognised deployment - ids would shadow the real capability.""" - row = _byok_row( - id_=3, - model_name="my-azure-deployment-id-no-litellm-knows-this", - base_model="gpt-4o", - provider=LiteLLMProvider.AZURE_OPENAI, - ) - serialized = new_llm_config_routes._serialize_byok_config(row) - - assert serialized.supports_image_input is True - - -def test_serialize_byok_returns_pydantic_read_model(): - """The route now returns ``NewLLMConfigRead`` (not the raw ORM) so - the schema additions are guaranteed to be present in the API - surface. This guards against a future regression where someone - deletes the augmentation step and falls back to ORM passthrough.""" - from app.schemas import NewLLMConfigRead - - row = _byok_row(id_=4, model_name="gpt-4o") - serialized = new_llm_config_routes._serialize_byok_config(row) - assert isinstance(serialized, NewLLMConfigRead) diff --git a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py deleted file mode 100644 index fff61f14b..000000000 --- a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py +++ /dev/null @@ -1,184 +0,0 @@ -"""Unit tests for ``is_premium`` derivation on the global image-gen and -vision-LLM list endpoints. - -Chat globals (``GET /global-llm-configs``) already emit -``is_premium = (billing_tier == "premium")``. Image and vision did not, -which made the new-chat ``model-selector`` render the Free/Premium badge -on the Chat tab but skip it on the Image and Vision tabs (the selector -keys its badge logic off ``is_premium``). These tests pin parity: - -* YAML free entry → ``is_premium=False`` -* YAML premium entry → ``is_premium=True`` -* OpenRouter dynamic premium entry → ``is_premium=True`` -* Auto stub (always emitted when at least one config is present) - → ``is_premium=False`` -""" - -from __future__ import annotations - -import pytest - -pytestmark = pytest.mark.unit - - -_IMAGE_FIXTURE: list[dict] = [ - { - "id": -1, - "name": "DALL-E 3", - "litellm_provider": "openai", - "model_name": "dall-e-3", - "api_key": "sk-test", - "billing_tier": "free", - }, - { - "id": -2, - "name": "GPT-Image 1 (premium)", - "litellm_provider": "openai", - "model_name": "gpt-image-1", - "api_key": "sk-test", - "billing_tier": "premium", - }, - { - "id": -20_001, - "name": "google/gemini-2.5-flash-image (OpenRouter)", - "litellm_provider": "openrouter", - "model_name": "google/gemini-2.5-flash-image", - "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - "billing_tier": "premium", - }, -] - - -_VISION_FIXTURE: list[dict] = [ - { - "id": -1, - "name": "GPT-4o Vision", - "litellm_provider": "openai", - "model_name": "gpt-4o", - "api_key": "sk-test", - "billing_tier": "free", - }, - { - "id": -2, - "name": "Claude 3.5 Sonnet (premium)", - "litellm_provider": "anthropic", - "model_name": "claude-3-5-sonnet", - "api_key": "sk-ant-test", - "billing_tier": "premium", - }, - { - "id": -30_001, - "name": "openai/gpt-4o (OpenRouter)", - "litellm_provider": "openrouter", - "model_name": "openai/gpt-4o", - "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - "billing_tier": "premium", - }, -] - - -# ============================================================================= -# Image generation -# ============================================================================= - - -@pytest.mark.asyncio -async def test_global_image_gen_configs_emit_is_premium(monkeypatch): - """Each emitted config must carry ``is_premium`` derived server-side - from ``billing_tier``. The Auto stub is always free. - """ - from app.config import config - from app.routes import image_generation_routes - - monkeypatch.setattr( - config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False - ) - - payload = await image_generation_routes.get_global_image_gen_configs(user=None) - - by_id = {c["id"]: c for c in payload} - - # Auto stub is always emitted when at least one global config exists, - # and it must always declare itself free (Auto-mode billing-tier - # surfacing is a separate follow-up). - assert 0 in by_id, "Auto stub should be emitted when at least one config exists" - assert by_id[0]["is_premium"] is False - assert by_id[0]["billing_tier"] == "free" - - # YAML free entry — ``is_premium=False`` - assert by_id[-1]["is_premium"] is False - assert by_id[-1]["billing_tier"] == "free" - - # YAML premium entry — ``is_premium=True`` - assert by_id[-2]["is_premium"] is True - assert by_id[-2]["billing_tier"] == "premium" - - # OpenRouter dynamic premium entry — same field, same derivation - assert by_id[-20_001]["is_premium"] is True - assert by_id[-20_001]["billing_tier"] == "premium" - - # Every emitted dict (including Auto) must have the field — never missing. - for cfg in payload: - assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}" - assert isinstance(cfg["is_premium"], bool) - - -@pytest.mark.asyncio -async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch): - """When there are no global configs at all, the endpoint emits an - empty list (no Auto stub) — Auto mode would have nothing to route to. - """ - from app.config import config - from app.routes import image_generation_routes - - monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False) - payload = await image_generation_routes.get_global_image_gen_configs(user=None) - assert payload == [] - - -# ============================================================================= -# Vision LLM -# ============================================================================= - - -@pytest.mark.asyncio -async def test_global_vision_llm_configs_emit_is_premium(monkeypatch): - from app.config import config - from app.routes import vision_llm_routes - - monkeypatch.setattr( - config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False - ) - - payload = await vision_llm_routes.get_global_vision_llm_configs(user=None) - - by_id = {c["id"]: c for c in payload} - - assert 0 in by_id, "Auto stub should be emitted when at least one config exists" - assert by_id[0]["is_premium"] is False - assert by_id[0]["billing_tier"] == "free" - - assert by_id[-1]["is_premium"] is False - assert by_id[-1]["billing_tier"] == "free" - - assert by_id[-2]["is_premium"] is True - assert by_id[-2]["billing_tier"] == "premium" - - assert by_id[-30_001]["is_premium"] is True - assert by_id[-30_001]["billing_tier"] == "premium" - - for cfg in payload: - assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}" - assert isinstance(cfg["is_premium"], bool) - - -@pytest.mark.asyncio -async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch): - from app.config import config - from app.routes import vision_llm_routes - - monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False) - payload = await vision_llm_routes.get_global_vision_llm_configs(user=None) - assert payload == [] diff --git a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py deleted file mode 100644 index 67d1112f3..000000000 --- a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Unit tests for ``supports_image_input`` derivation on the chat global -config endpoint (``GET /global-new-llm-configs``). - -Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``): - -1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML - loader for operator overrides, or by the OpenRouter integration from - ``architecture.input_modalities``) — wins. -2. ``derive_supports_image_input`` helper — default-allow on unknown - models, only False when LiteLLM / OR modalities are definitive. - -The flag is purely informational at the API boundary. The streaming -task safety net (``is_known_text_only_chat_model``) is the actual block, -and it requires LiteLLM to *explicitly* mark the model as text-only. -""" - -from __future__ import annotations - -import pytest - -pytestmark = pytest.mark.unit - - -_FIXTURE: list[dict] = [ - { - "id": -1, - "name": "GPT-4o (explicit true)", - "description": "vision-capable, explicit YAML override", - "litellm_provider": "openai", - "model_name": "gpt-4o", - "api_key": "sk-test", - "billing_tier": "free", - "supports_image_input": True, - }, - { - "id": -2, - "name": "DeepSeek V3 (explicit false)", - "description": "OpenRouter dynamic — modality-derived false", - "litellm_provider": "openrouter", - "model_name": "deepseek/deepseek-v3.2-exp", - "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - "billing_tier": "free", - "supports_image_input": False, - }, - { - "id": -10_010, - "name": "Unannotated GPT-4o", - "description": "no flag set — resolver should derive True via LiteLLM", - "litellm_provider": "openai", - "model_name": "gpt-4o", - "api_key": "sk-test", - "billing_tier": "free", - # supports_image_input intentionally absent - }, - { - "id": -10_011, - "name": "Unannotated unknown model", - "description": "unmapped — default-allow True", - "litellm_provider": "custom", - "custom_provider": "brand_new_proxy", - "model_name": "brand-new-model-x9", - "api_key": "sk-test", - "billing_tier": "free", - }, -] - - -@pytest.mark.asyncio -async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch): - """Each emitted chat config carries ``supports_image_input`` as a - bool. Explicit values win; unannotated entries are resolved via the - helper (default-allow True).""" - from app.config import config - from app.routes import new_llm_config_routes - - monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False) - - payload = await new_llm_config_routes.get_global_new_llm_configs(user=None) - by_id = {c["id"]: c for c in payload} - - # Auto stub: optimistic True so the user can keep Auto selected with - # vision-capable deployments somewhere in the pool. - assert 0 in by_id, "Auto stub should be emitted when configs exist" - assert by_id[0]["supports_image_input"] is True - assert by_id[0]["is_auto_mode"] is True - - # Explicit True is preserved. - assert by_id[-1]["supports_image_input"] is True - - # Explicit False is preserved (the exact failure mode the safety net - # guards against — DeepSeek V3 over OpenRouter would 404 with "No - # endpoints found that support image input"). - assert by_id[-2]["supports_image_input"] is False - - # Unannotated GPT-4o: resolver consults LiteLLM, which says vision. - assert by_id[-10_010]["supports_image_input"] is True - - # Unknown / unmapped model: default-allow rather than pre-judge. - assert by_id[-10_011]["supports_image_input"] is True - - for cfg in payload: - assert "supports_image_input" in cfg, ( - f"supports_image_input missing from {cfg.get('id')}" - ) - assert isinstance(cfg["supports_image_input"], bool) diff --git a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py index 53c0f50a9..4dd918927 100644 --- a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py +++ b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py @@ -27,9 +27,18 @@ async def test_resolve_billing_for_auto_mode(monkeypatch): from app.routes import image_generation_routes from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS - search_space = SimpleNamespace(image_generation_config_id=None) + async def _no_auto_candidates(*_args, **_kwargs): + return [] + + monkeypatch.setattr( + image_generation_routes, + "auto_model_candidates", + _no_auto_candidates, + ) + + search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None) tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( - session=None, # Not consumed on this code path. + session=None, config_id=0, # IMAGE_GEN_AUTO_MODE_ID search_space=search_space, ) @@ -45,26 +54,42 @@ async def test_resolve_billing_for_premium_global_config(monkeypatch): monkeypatch.setattr( config, - "GLOBAL_IMAGE_GEN_CONFIGS", + "GLOBAL_MODELS", [ { "id": -1, - "litellm_provider": "openai", - "model_name": "gpt-image-1", + "connection_id": -101, + "model_id": "gpt-image-1", "billing_tier": "premium", - "quota_reserve_micros": 75_000, + "catalog": {"quota_reserve_micros": 75_000}, }, { "id": -2, - "litellm_provider": "openrouter", - "model_name": "google/gemini-2.5-flash-image", + "connection_id": -102, + "model_id": "google/gemini-2.5-flash-image", "billing_tier": "free", + "catalog": {}, + }, + ], + raising=False, + ) + monkeypatch.setattr( + config, + "GLOBAL_CONNECTIONS", + [ + {"id": -101, "provider": "openai", "api_key": "sk-test", "base_url": None, "extra": {}}, + { + "id": -102, + "provider": "openrouter", + "api_key": "sk-or-test", + "base_url": "https://openrouter.ai/api/v1", + "extra": {}, }, ], raising=False, ) - search_space = SimpleNamespace(image_generation_config_id=None) + search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None) # Premium with override. tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( @@ -94,7 +119,7 @@ async def test_resolve_billing_for_user_owned_byok_is_free(): from app.routes import image_generation_routes from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS - search_space = SimpleNamespace(image_generation_config_id=None) + search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None) tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( session=None, config_id=42, search_space=search_space ) @@ -105,7 +130,7 @@ async def test_resolve_billing_for_user_owned_byok_is_free(): @pytest.mark.asyncio async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch): - """When the request omits ``image_generation_config_id``, the helper + """When the request omits ``image_gen_model_id``, the helper must consult the search space's default — so a search space pinned to a premium global config still gates new requests by quota. """ @@ -114,19 +139,26 @@ async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch): monkeypatch.setattr( config, - "GLOBAL_IMAGE_GEN_CONFIGS", + "GLOBAL_MODELS", [ { "id": -7, - "litellm_provider": "openai", - "model_name": "gpt-image-1", + "connection_id": -101, + "model_id": "gpt-image-1", "billing_tier": "premium", + "catalog": {}, } ], raising=False, ) + monkeypatch.setattr( + config, + "GLOBAL_CONNECTIONS", + [{"id": -101, "provider": "openai", "api_key": "sk-test", "base_url": None, "extra": {}}], + raising=False, + ) - search_space = SimpleNamespace(image_generation_config_id=-7) + search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=-7) ( tier, model, diff --git a/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py index fa8819b39..b43540ba7 100644 --- a/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py +++ b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py @@ -1,27 +1,4 @@ -"""Unit tests for ``_resolve_agent_billing_for_search_space``. - -Validates the resolver used by Celery podcast/video tasks to compute -``(owner_user_id, billing_tier, base_model)`` from a search space and its -agent LLM config. The resolver mirrors chat's billing-resolution pattern at -``stream_new_chat.py:2294-2351`` and is the single integration point that -prevents Auto-mode podcast/video from leaking premium credit. - -Coverage: - -* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium - global → returns ``("premium", )``. -* Auto mode + ``thread_id`` set, pin resolves to a negative-id free - global → returns ``("free", )``. -* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config - → always ``"free"``. -* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without - hitting the pin service. -* Negative id (no Auto) → uses ``get_global_llm_config``'s - ``billing_tier``. -* Positive id (user BYOK) → always ``"free"``. -* Search space not found → raises ``ValueError``. -* ``agent_llm_id`` is None → raises ``ValueError``. -""" +"""Unit tests for ``_resolve_agent_billing_for_search_space``.""" from __future__ import annotations @@ -34,11 +11,6 @@ import pytest pytestmark = pytest.mark.unit -# --------------------------------------------------------------------------- -# Fakes -# --------------------------------------------------------------------------- - - class _FakeExecResult: def __init__(self, obj): self._obj = obj @@ -51,14 +23,6 @@ class _FakeExecResult: class _FakeSession: - """Tiny AsyncSession stub. - - ``responses`` is a list of objects to return from successive - ``execute()`` calls (in order). The resolver makes at most two - ``execute()`` calls (search-space lookup, then optionally NewLLMConfig - lookup), so two queued responses cover the matrix. - """ - def __init__(self, responses: list): self._responses = list(responses) @@ -67,9 +31,6 @@ class _FakeSession: return _FakeExecResult(None) return _FakeExecResult(self._responses.pop(0)) - async def commit(self) -> None: - pass - @dataclass class _FakePinResolution: @@ -78,53 +39,33 @@ class _FakePinResolution: from_existing_pin: bool = False -def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace: - return SimpleNamespace( - id=42, - agent_llm_id=agent_llm_id, - user_id=user_id, - ) +def _make_search_space(*, chat_model_id: int | None, user_id: UUID) -> SimpleNamespace: + return SimpleNamespace(id=42, chat_model_id=chat_model_id, user_id=user_id) -def _make_byok_config( - *, id_: int, base_model: str | None = None, model_name: str = "gpt-byok" +def _make_byok_model( + *, id_: int, base_model: str | None = None, model_id: str = "gpt-byok" ) -> SimpleNamespace: return SimpleNamespace( id=id_, - model_name=model_name, - litellm_params={"base_model": base_model} if base_model else {}, + model_id=model_id, + catalog={"base_model": base_model} if base_model else {}, + connection=SimpleNamespace(enabled=True, search_space_id=42, user_id=None), ) -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - @pytest.mark.asyncio async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch): - """Auto + thread → pin service resolves to negative-id premium config → - resolver returns ``("premium", )``.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)]) - # Mock the pin service to return a concrete premium config id. - async def _fake_resolve_pin( - sess, - *, - thread_id, - search_space_id, - user_id, - selected_llm_config_id, - force_repin_free=False, - ): - assert selected_llm_config_id == 0 - assert thread_id == 99 + async def _fake_resolve_pin(*_args, **kwargs): + assert kwargs["selected_llm_config_id"] == 0 + assert kwargs["thread_id"] == 99 return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium") - # Mock global config lookup to return a premium entry. def _fake_get_global(cfg_id): if cfg_id == -1: return { @@ -135,8 +76,6 @@ async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch): } return None - # Lazy imports inside the resolver — patch the *target* modules so the - # imported names resolve to our fakes. import app.services.auto_model_pin_service as pin_module import app.services.llm_service as llm_module @@ -154,77 +93,18 @@ async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch): assert base_model == "gpt-5.4" -@pytest.mark.asyncio -async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch): - """Auto + thread → pin returns negative-id free config → resolver - returns ``("free", )``. Same path the pin service takes for - out-of-credit users (graceful degradation).""" - from app.services.billable_calls import _resolve_agent_billing_for_search_space - - user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) - - async def _fake_resolve_pin( - sess, - *, - thread_id, - search_space_id, - user_id, - selected_llm_config_id, - force_repin_free=False, - ): - return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free") - - def _fake_get_global(cfg_id): - if cfg_id == -3: - return { - "id": -3, - "model_name": "openrouter/free-model", - "billing_tier": "free", - "litellm_params": {"base_model": "openrouter/free-model"}, - } - return None - - import app.services.auto_model_pin_service as pin_module - import app.services.llm_service as llm_module - - monkeypatch.setattr( - pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin - ) - monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) - - owner, tier, base_model = await _resolve_agent_billing_for_search_space( - session, search_space_id=42, thread_id=99 - ) - - assert owner == user_id - assert tier == "free" - assert base_model == "openrouter/free-model" - - @pytest.mark.asyncio async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch): - """Auto + thread → pin returns positive-id BYOK config → resolver - returns ``("free", ...)`` (BYOK is always free per - ``AgentConfig.from_new_llm_config``).""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - search_space = _make_search_space(agent_llm_id=0, user_id=user_id) - byok_cfg = _make_byok_config( - id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude" + search_space = _make_search_space(chat_model_id=0, user_id=user_id) + byok_model = _make_byok_model( + id_=17, base_model="anthropic/claude-3-haiku", model_id="my-claude" ) - session = _FakeSession([search_space, byok_cfg]) + session = _FakeSession([search_space, byok_model]) - async def _fake_resolve_pin( - sess, - *, - thread_id, - search_space_id, - user_id, - selected_llm_config_id, - force_repin_free=False, - ): + async def _fake_resolve_pin(*_args, **_kwargs): return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free") import app.services.auto_model_pin_service as pin_module @@ -244,13 +124,10 @@ async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch): @pytest.mark.asyncio async def test_auto_mode_without_thread_id_falls_back_to_free(): - """Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking - the pin service. Forward-compat fallback for any future direct-API - entrypoint that doesn't have a chat thread.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)]) owner, tier, base_model = await _resolve_agent_billing_for_search_space( session, search_space_id=42, thread_id=None @@ -263,13 +140,10 @@ async def test_auto_mode_without_thread_id_falls_back_to_free(): @pytest.mark.asyncio async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch): - """If the pin service raises ``ValueError`` (thread missing / - mismatched search space), the resolver should log and return free - rather than killing the whole task.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)]) async def _fake_resolve_pin(*args, **kwargs): raise ValueError("thread missing") @@ -291,12 +165,10 @@ async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch): @pytest.mark.asyncio async def test_negative_id_premium_global_returns_premium(monkeypatch): - """Explicit negative agent_llm_id → ``get_global_llm_config`` → - return its ``billing_tier``.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=-1, user_id=user_id)]) def _fake_get_global(cfg_id): return { @@ -319,50 +191,15 @@ async def test_negative_id_premium_global_returns_premium(monkeypatch): assert base_model == "gpt-5.4" -@pytest.mark.asyncio -async def test_negative_id_free_global_returns_free(monkeypatch): - from app.services.billable_calls import _resolve_agent_billing_for_search_space - - user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)]) - - def _fake_get_global(cfg_id): - return { - "id": cfg_id, - "model_name": "openrouter/some-free", - "billing_tier": "free", - "litellm_params": {"base_model": "openrouter/some-free"}, - } - - import app.services.llm_service as llm_module - - monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) - - owner, tier, base_model = await _resolve_agent_billing_for_search_space( - session, search_space_id=42, thread_id=None - ) - - assert owner == user_id - assert tier == "free" - assert base_model == "openrouter/some-free" - - @pytest.mark.asyncio async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch): - """When the global config has no ``litellm_params.base_model``, the - resolver falls back to ``model_name`` — matching chat's behavior.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=-5, user_id=user_id)]) def _fake_get_global(cfg_id): - return { - "id": cfg_id, - "model_name": "fallback-model", - "billing_tier": "premium", - # No litellm_params. - } + return {"id": cfg_id, "model_name": "fallback-model", "billing_tier": "premium"} import app.services.llm_service as llm_module @@ -378,14 +215,12 @@ async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypat @pytest.mark.asyncio async def test_positive_id_byok_is_always_free(): - """Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free, - regardless of underlying provider tier.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - search_space = _make_search_space(agent_llm_id=23, user_id=user_id) - byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet") - session = _FakeSession([search_space, byok_cfg]) + search_space = _make_search_space(chat_model_id=23, user_id=user_id) + byok_model = _make_byok_model(id_=23, base_model="anthropic/claude-3.5-sonnet") + session = _FakeSession([search_space, byok_model]) owner, tier, base_model = await _resolve_agent_billing_for_search_space( session, search_space_id=42 @@ -398,13 +233,10 @@ async def test_positive_id_byok_is_always_free(): @pytest.mark.asyncio async def test_positive_id_byok_missing_returns_free_with_empty_base_model(): - """If the BYOK config row is missing/deleted but the search space still - points at it, the resolver still returns free (no debit) with an empty - base_model — billable_call's premium path is skipped, no harm done.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=99, user_id=user_id)]) owner, tier, base_model = await _resolve_agent_billing_for_search_space( session, search_space_id=42 @@ -419,18 +251,18 @@ async def test_positive_id_byok_missing_returns_free_with_empty_base_model(): async def test_search_space_not_found_raises_value_error(): from app.services.billable_calls import _resolve_agent_billing_for_search_space - session = _FakeSession([None]) - with pytest.raises(ValueError, match="Search space"): - await _resolve_agent_billing_for_search_space(session, search_space_id=999) + await _resolve_agent_billing_for_search_space( + _FakeSession([None]), search_space_id=999 + ) @pytest.mark.asyncio -async def test_agent_llm_id_none_raises_value_error(): +async def test_chat_model_id_none_raises_value_error(): from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=None, user_id=user_id)]) - with pytest.raises(ValueError, match="agent_llm_id"): + with pytest.raises(ValueError, match="chat_model_id"): await _resolve_agent_billing_for_search_space(session, search_space_id=42) diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index d7eb32732..d7c12a6e0 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -32,8 +32,9 @@ class _FakeQuotaResult: class _FakeExecResult: - def __init__(self, thread): + def __init__(self, *, thread=None, scalars=None): self._thread = thread + self._scalars = scalars or [] def unique(self): return self @@ -41,19 +42,69 @@ class _FakeExecResult: def scalar_one_or_none(self): return self._thread + def scalars(self): + return SimpleNamespace(all=lambda: self._scalars) + class _FakeSession: - def __init__(self, thread): + def __init__(self, thread, *, models=None): self.thread = thread + self.models = models or [] self.commit_count = 0 + self.execute_count = 0 async def execute(self, _stmt): - return _FakeExecResult(self.thread) + self.execute_count += 1 + if self.execute_count == 1: + return _FakeExecResult(thread=self.thread) + return _FakeExecResult(scalars=self.models) async def commit(self): self.commit_count += 1 +def _set_global_llm_configs(monkeypatch, config, configs: list[dict]): + """Patch the new global model catalog shape from compact legacy cfg fixtures.""" + connections = [] + models = [] + for cfg in configs: + config_id = int(cfg["id"]) + connection_id = config_id - 100_000 + provider = cfg.get("provider") or cfg.get("litellm_provider") + model_name = cfg["model_name"] + connections.append( + { + "id": connection_id, + "provider": provider, + "scope": "GLOBAL", + "enabled": True, + } + ) + models.append( + { + "id": config_id, + "connection_id": connection_id, + "model_id": model_name, + "display_name": cfg.get("name") or model_name, + "supports_chat": cfg.get("supports_chat", True), + "supports_image_input": cfg.get("supports_image_input", True), + "supports_tools": cfg.get("supports_tools", True), + "supports_image_generation": cfg.get("supports_image_generation", False), + "capabilities_override": cfg.get("capabilities_override") or {}, + "billing_tier": cfg.get("billing_tier", "free"), + "catalog": { + "auto_pin_tier": cfg.get("auto_pin_tier"), + "quality_score": cfg.get("quality_score") + or cfg.get("quality_score_static"), + }, + } + ) + + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", configs) + monkeypatch.setattr(config, "GLOBAL_CONNECTIONS", connections) + monkeypatch.setattr(config, "GLOBAL_MODELS", models) + + def _thread( *, search_space_id: int = 10, @@ -71,9 +122,9 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"}, { @@ -111,9 +162,9 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -2, @@ -158,9 +209,9 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -216,9 +267,9 @@ async def test_next_turn_reuses_existing_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -257,9 +308,9 @@ async def test_premium_eligible_auto_can_pin_premium(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -295,9 +346,9 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -2, @@ -340,9 +391,9 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -2, @@ -385,9 +436,9 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -2, @@ -433,9 +484,9 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-2)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"}, ], @@ -458,9 +509,9 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-999)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"}, ], @@ -487,7 +538,7 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): # --------------------------------------------------------------------------- -# Quality-aware pin selection (Auto Fastest upgrade) +# Quality-aware pin selection (Auto upgrade) # --------------------------------------------------------------------------- @@ -498,9 +549,9 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -550,9 +601,9 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -602,9 +653,9 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -676,9 +727,9 @@ async def test_top_k_picks_only_high_score_models(monkeypatch): "quality_score": 10, "health_gated": False, } - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [*high_score_cfgs, low_score_trap], ) @@ -723,9 +774,9 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -775,9 +826,9 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -833,9 +884,9 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -886,9 +937,9 @@ async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -931,9 +982,9 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, diff --git a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py index 63aa934a3..5850dfe23 100644 --- a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py +++ b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py @@ -15,15 +15,19 @@ async def test_global_openrouter_image_gen_sets_explicit_api_base(): """The global-config branch forwards the explicit OpenRouter base.""" from app.routes import image_generation_routes - cfg = { + global_model = { "id": -20_001, - "name": "GPT Image 1 (OpenRouter)", - "litellm_provider": "openrouter", - "model_name": "openai/gpt-image-1", + "connection_id": -101, + "model_id": "openai/gpt-image-1", + "supports_image_generation": True, + "capabilities_override": {}, + } + global_connection = { + "id": -101, + "provider": "openrouter", "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - "api_version": None, - "litellm_params": {}, + "base_url": "https://openrouter.ai/api/v1", + "extra": {}, } captured: dict = {} @@ -33,7 +37,7 @@ async def test_global_openrouter_image_gen_sets_explicit_api_base(): return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={}) image_gen = MagicMock() - image_gen.image_generation_config_id = cfg["id"] + image_gen.image_gen_model_id = global_model["id"] image_gen.prompt = "test" image_gen.n = 1 image_gen.quality = None @@ -43,14 +47,19 @@ async def test_global_openrouter_image_gen_sets_explicit_api_base(): image_gen.model = None search_space = MagicMock() - search_space.image_generation_config_id = cfg["id"] + search_space.image_gen_model_id = global_model["id"] session = MagicMock() with ( patch.object( image_generation_routes, - "_get_global_image_gen_config", - return_value=cfg, + "_get_global_model", + return_value=global_model, + ), + patch.object( + image_generation_routes, + "_get_global_connection", + return_value=global_connection, ), patch.object( image_generation_routes, @@ -74,15 +83,19 @@ async def test_generate_image_tool_global_sets_explicit_api_base(): generate_image as gi_module, ) - cfg = { + global_model = { "id": -20_001, - "name": "GPT Image 1 (OpenRouter)", - "litellm_provider": "openrouter", - "model_name": "openai/gpt-image-1", + "connection_id": -101, + "model_id": "openai/gpt-image-1", + "supports_image_generation": True, + "capabilities_override": {}, + } + global_connection = { + "id": -101, + "provider": "openrouter", "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - "api_version": None, - "litellm_params": {}, + "base_url": "https://openrouter.ai/api/v1", + "extra": {}, } captured: dict = {} @@ -98,7 +111,7 @@ async def test_generate_image_tool_global_sets_explicit_api_base(): search_space = MagicMock() search_space.id = 1 - search_space.image_generation_config_id = cfg["id"] + search_space.image_gen_model_id = global_model["id"] session_cm = AsyncMock() session = AsyncMock() @@ -121,7 +134,8 @@ async def test_generate_image_tool_global_sets_explicit_api_base(): with ( patch.object(gi_module, "shielded_async_session", return_value=session_cm), - patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg), + patch.object(gi_module, "_get_global_model", return_value=global_model), + patch.object(gi_module, "_get_global_connection", return_value=global_connection), patch.object( gi_module, "aimage_generation", side_effect=fake_aimage_generation ), diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py index 9d4c1a04b..ee97aac4d 100644 --- a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -217,7 +217,7 @@ def test_generate_configs_drops_non_text_and_non_tool_models(): # --------------------------------------------------------------------------- -# _generate_image_gen_configs / _generate_vision_llm_configs +# _generate_image_gen_configs # --------------------------------------------------------------------------- @@ -263,7 +263,7 @@ def test_generate_image_gen_configs_filters_by_image_output(): # Each config must carry ``billing_tier`` for routing in image_generation_routes. for c in cfgs: assert c["billing_tier"] in {"free", "premium"} - assert c["litellm_provider"] == "openrouter" + assert c["provider"] == "openrouter" assert c[_OPENROUTER_DYNAMIC_MARKER] is True # Emit the OpenRouter base URL at source so every call path passes an # explicit api_base and cannot inherit a process-global endpoint. @@ -271,9 +271,7 @@ def test_generate_image_gen_configs_filters_by_image_output(): def test_generate_image_gen_configs_assigns_image_id_offset(): - """Image configs use a different id_offset (-20000) so their negative - IDs don't collide with chat configs (-10000) or vision configs (-30000). - """ + """Image configs use their own id_offset (-20000).""" from app.services.openrouter_integration_service import ( _generate_image_gen_configs, ) @@ -291,88 +289,3 @@ def test_generate_image_gen_configs_assigns_image_id_offset(): assert all(c["id"] < -20_000 + 1 for c in cfgs) assert all(c["id"] > -29_000_000 for c in cfgs) - -def test_generate_vision_llm_configs_filters_by_image_input_text_output(): - """Vision LLMs must accept image input AND emit text — pure image-gen - (no text out) and text-only (no image in) models are excluded. - """ - from app.services.openrouter_integration_service import ( - _generate_vision_llm_configs, - ) - - raw = [ - # GPT-4o: vision LLM (image in, text out) — must emit. - { - "id": "openai/gpt-4o", - "architecture": { - "input_modalities": ["text", "image"], - "output_modalities": ["text"], - }, - "context_length": 128_000, - "pricing": {"prompt": "0.000005", "completion": "0.000015"}, - }, - # Pure image generator — image *output*, no text out. Must NOT emit. - { - "id": "openai/gpt-image-1", - "architecture": { - "input_modalities": ["text"], - "output_modalities": ["image"], - }, - "context_length": 4_000, - "pricing": {"prompt": "0", "completion": "0"}, - }, - # Pure text model (no image in). Must NOT emit. - { - "id": "anthropic/claude-3-haiku", - "architecture": { - "input_modalities": ["text"], - "output_modalities": ["text"], - }, - "context_length": 200_000, - "pricing": {"prompt": "0.000001", "completion": "0.000005"}, - }, - ] - - cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE)) - names = {c["model_name"] for c in cfgs} - assert names == {"openai/gpt-4o"} - - cfg = cfgs[0] - assert cfg["billing_tier"] == "premium" - # Pricing carried inline so pricing_registration can register vision - # under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache - # is cleared. - assert cfg["input_cost_per_token"] == pytest.approx(5e-6) - assert cfg["output_cost_per_token"] == pytest.approx(15e-6) - assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True - # Emit the OpenRouter base URL at source so every call path passes an - # explicit api_base and cannot inherit a process-global endpoint. - assert cfg["api_base"] == "https://openrouter.ai/api/v1" - - -def test_generate_vision_llm_configs_drops_chat_only_filters(): - """A small-context vision model that doesn't advertise tool calling is - still a valid vision LLM for "describe this image" prompts. The chat - filters (``supports_tool_calling``, ``has_sufficient_context``) must - NOT be applied to vision emission. - """ - from app.services.openrouter_integration_service import ( - _generate_vision_llm_configs, - ) - - raw = [ - { - "id": "tiny/vision-mini", - "architecture": { - "input_modalities": ["text", "image"], - "output_modalities": ["text"], - }, - "supported_parameters": [], # no tools - "context_length": 4_000, # well below MIN_CONTEXT_LENGTH - "pricing": {"prompt": "0.0000001", "completion": "0.0000005"}, - } - ] - - cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE)) - assert len(cfgs) == 1 - assert cfgs[0]["model_name"] == "tiny/vision-mini" diff --git a/surfsense_backend/tests/unit/services/test_pricing_registration.py b/surfsense_backend/tests/unit/services/test_pricing_registration.py index c9adc6aac..ee2faf674 100644 --- a/surfsense_backend/tests/unit/services/test_pricing_registration.py +++ b/surfsense_backend/tests/unit/services/test_pricing_registration.py @@ -370,77 +370,3 @@ def test_register_continues_after_individual_failure(monkeypatch, caplog): assert any("custom-deployment" in payload for payload in successful_calls) -def test_vision_configs_registered_with_chat_shape(monkeypatch): - """``register_pricing_from_global_configs`` walks - ``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision - calls (during indexing) bill correctly. Vision configs use the same - chat-shape token prices, but image-gen pricing is intentionally NOT - registered here (handled via ``response_cost`` in LiteLLM). - """ - from app.config import config - from app.services.pricing_registration import register_pricing_from_global_configs - - spy = _patch_register(monkeypatch) - _patch_openrouter_pricing( - monkeypatch, - {"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}}, - ) - - # No chat configs — only vision. Proves the vision walk is a separate - # iteration, not piggy-backed on the chat list. - monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", []) - monkeypatch.setattr( - config, - "GLOBAL_VISION_LLM_CONFIGS", - [ - { - "id": -1, - "litellm_provider": "openrouter", - "model_name": "openai/gpt-4o", - "billing_tier": "premium", - "input_cost_per_token": 5e-6, - "output_cost_per_token": 15e-6, - } - ], - ) - - register_pricing_from_global_configs() - - assert "openrouter/openai/gpt-4o" in spy.all_keys - payload_value = spy.calls[0]["openrouter/openai/gpt-4o"] - assert payload_value["mode"] == "chat" - assert payload_value["litellm_provider"] == "openrouter" - assert payload_value["input_cost_per_token"] == pytest.approx(5e-6) - assert payload_value["output_cost_per_token"] == pytest.approx(15e-6) - - -def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch): - """If the OpenRouter pricing cache misses a vision model (different - catalogue surface), the vision walk falls back to inline - ``input_cost_per_token``/``output_cost_per_token`` on the cfg itself. - """ - from app.config import config - from app.services.pricing_registration import register_pricing_from_global_configs - - spy = _patch_register(monkeypatch) - _patch_openrouter_pricing(monkeypatch, {}) - - monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", []) - monkeypatch.setattr( - config, - "GLOBAL_VISION_LLM_CONFIGS", - [ - { - "id": -1, - "litellm_provider": "openrouter", - "model_name": "google/gemini-2.5-flash", - "billing_tier": "premium", - "input_cost_per_token": 1e-6, - "output_cost_per_token": 4e-6, - } - ], - ) - - register_pricing_from_global_configs() - - assert "openrouter/google/gemini-2.5-flash" in spy.all_keys diff --git a/surfsense_backend/tests/unit/services/test_quality_score.py b/surfsense_backend/tests/unit/services/test_quality_score.py index 369c8b8f3..cb3f7523a 100644 --- a/surfsense_backend/tests/unit/services/test_quality_score.py +++ b/surfsense_backend/tests/unit/services/test_quality_score.py @@ -1,4 +1,4 @@ -"""Unit tests for the Auto (Fastest) quality scoring module.""" +"""Unit tests for the Auto quality scoring module.""" from __future__ import annotations diff --git a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py deleted file mode 100644 index 48dfc8e0b..000000000 --- a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py +++ /dev/null @@ -1,77 +0,0 @@ -"""Vision LLM resolution must pass explicit per-config ``api_base``.""" - -from __future__ import annotations - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -pytestmark = pytest.mark.unit - - -@pytest.mark.asyncio -async def test_get_vision_llm_global_openrouter_sets_api_base(): - """Global negative-ID branch forwards the explicit OpenRouter base.""" - from app.services import llm_service - - cfg = { - "id": -30_001, - "name": "GPT-4o Vision (OpenRouter)", - "litellm_provider": "openrouter", - "model_name": "openai/gpt-4o", - "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - "api_version": None, - "litellm_params": {}, - "billing_tier": "free", - } - - search_space = MagicMock() - search_space.id = 1 - search_space.user_id = "user-x" - search_space.vision_llm_config_id = cfg["id"] - - session = AsyncMock() - scalars = MagicMock() - scalars.first.return_value = search_space - result = MagicMock() - result.scalars.return_value = scalars - session.execute.return_value = result - - captured: dict = {} - - class FakeSanitized: - def __init__(self, **kwargs): - captured.update(kwargs) - - with ( - patch( - "app.services.vision_llm_router_service.get_global_vision_llm_config", - return_value=cfg, - ), - patch( - "app.agents.chat.runtime.llm_config.SanitizedChatLiteLLM", - new=FakeSanitized, - ), - ): - await llm_service.get_vision_llm(session=session, search_space_id=1) - - assert captured.get("api_base") == "https://openrouter.ai/api/v1" - assert captured["model"] == "openrouter/openai/gpt-4o" - - -def test_vision_router_deployment_sets_api_base_when_config_empty(): - """Auto-mode vision router carries explicit api_base into deployments.""" - from app.services.vision_llm_router_service import VisionLLMRouterService - - deployment = VisionLLMRouterService._config_to_deployment( - { - "model_name": "openai/gpt-4o", - "litellm_provider": "openrouter", - "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - } - ) - assert deployment is not None - assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1" - assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-4o" diff --git a/surfsense_evals/README.md b/surfsense_evals/README.md index c755c4de6..e6fc52ca1 100644 --- a/surfsense_evals/README.md +++ b/surfsense_evals/README.md @@ -77,7 +77,7 @@ The walkthrough above is `--scenario head-to-head` (default): both arms answer w | `symmetric-cheap` | `--provider-model` (cheap, text-only) | `--provider-model` (same) | Does pre-extracted image context let a non-vision LLM reason over image-heavy docs? | | `cost-arbitrage` | `--native-arm-model` (vision) | `--provider-model` (cheap) | How close does SurfSense get to a vision-native baseline at a fraction of per-query cost?| -In all three modes the **ingest-time** vision LLM is set on the SearchSpace's `vision_llm_config_id` (auto-picked from the strongest registered global OpenRouter vision config — `claude-sonnet-4.5` > `claude-opus-4.7` > `gpt-5` > `gemini-2.5-pro`, override with `--vision-llm `). What changes is which slug the *answering* models hit per arm. +In all three modes the **ingest-time** vision LLM is set on the SearchSpace's `vision_model_id` (auto-picked from the strongest registered global OpenRouter vision-capable model — `claude-sonnet-4.5` > `claude-opus-4.7` > `gpt-5` > `gemini-2.5-pro`, override with `--vision-llm `). What changes is which slug the *answering* models hit per arm. ### Ingest with vision, evaluate with a non-vision LLM (`symmetric-cheap`) @@ -118,7 +118,7 @@ python -m surfsense_evals report --suite medical Notes: - `cost-arbitrage` requires both `--provider-model` (the cheap SurfSense slug) AND `--native-arm-model `. -- `--vision-llm ` is optional; if omitted the harness queries `GET /api/v1/global-vision-llm-configs` and auto-picks the strongest registered one. Pass `--no-vision-llm-setup` if you want to keep whatever vision config is already attached to the SearchSpace. +- `--vision-llm ` is optional; if omitted the harness queries `GET /api/v1/model-connections/global` and auto-picks the strongest registered vision-capable model. Pass `--no-vision-llm-setup` if you want to keep whatever vision model is already attached to the SearchSpace. - The runner's "looks text-only" warning is suppressed (or relabelled as informational) for `symmetric-cheap` so intentional asymmetry doesn't read as a misconfiguration. - All three scenario fields (`scenario`, `provider_model`, `native_arm_model`, `vision_provider_model`) are persisted to `state.json` and recorded in `run_artifact.extra` + the report header — no need to retrace what was set. diff --git a/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json b/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json index a4687f64a..b6c59e2bc 100644 --- a/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json +++ b/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json @@ -9,7 +9,7 @@ "llamacloud_premium_lc", "surfsense_agentic" ], - "agent_llm_id": -5138454, + "chat_model_id": -5138454, "concurrency": 2, "llm_model": "anthropic/claude-sonnet-4.5", "n_pdfs": 30, diff --git a/surfsense_evals/src/surfsense_evals/core/cli.py b/surfsense_evals/src/surfsense_evals/core/cli.py index 3d4d0fd24..17979fba0 100644 --- a/surfsense_evals/src/surfsense_evals/core/cli.py +++ b/surfsense_evals/src/surfsense_evals/core/cli.py @@ -2,7 +2,7 @@ Subcommands: -* ``setup --suite --provider-model [--agent-llm-id ]`` +* ``setup --suite --provider-model [--chat-model-id ]`` * ``teardown --suite `` * ``models list [--provider openrouter] [--grep ]`` * ``suites list`` @@ -18,7 +18,7 @@ publish its own flags. Design choices worth flagging: -* ``setup`` rejects ``agent_llm_id == 0`` (Auto / LiteLLM router) so +* ``setup`` rejects ``chat_model_id == 0`` (Auto / LiteLLM router) so per-question accuracy is reproducible. * ``setup`` validates that the picked LLM config has ``provider == "OPENROUTER"`` and ``model_name == --provider-model`` @@ -59,7 +59,6 @@ if sys.platform == "win32": from . import registry from .auth import CredentialError, acquire_token, client_with_auth from .clients import SearchSpaceClient -from .clients.search_space import LlmPreferences from .config import ( DEFAULT_SCENARIO, SCENARIOS, @@ -111,23 +110,30 @@ class LlmConfigEntry: def from_payload(cls, payload: dict[str, Any]) -> LlmConfigEntry: return cls( id=int(payload["id"]), - name=str(payload.get("name", "")), + name=str(payload.get("display_name") or payload.get("name") or ""), provider=str(payload.get("provider", "")).upper(), - model_name=str(payload.get("model_name", "")), + model_name=str(payload.get("model_id") or payload.get("model_name") or ""), raw=payload, ) async def _list_global_llm_configs(http: httpx.AsyncClient, base: str) -> list[LlmConfigEntry]: response = await http.get( - f"{base}/api/v1/global-new-llm-configs", + f"{base}/api/v1/model-connections/global", headers={"Accept": "application/json"}, ) response.raise_for_status() payload = response.json() if not isinstance(payload, list): - raise RuntimeError(f"Unexpected /global-new-llm-configs payload: {payload!r}") - return [LlmConfigEntry.from_payload(item) for item in payload] + raise RuntimeError(f"Unexpected /model-connections/global payload: {payload!r}") + entries: list[LlmConfigEntry] = [] + for connection in payload: + provider = connection.get("provider", "") + for model in connection.get("models") or []: + if not model.get("enabled", True) or not model.get("supports_chat"): + continue + entries.append(LlmConfigEntry.from_payload({**model, "provider": provider})) + return entries def _resolve_openrouter_id( @@ -143,8 +149,8 @@ def _resolve_openrouter_id( * If ``explicit_id`` is given: return it directly. The caller is then expected to GET-validate that the row's ``provider == "OPENROUTER"`` and ``model_name`` matches the slug. - That branch supports positive BYOK ``NewLLMConfig`` rows whose - slugs may overlap with global OpenRouter virtuals. + That branch supports positive BYOK model rows whose slugs may overlap + with global OpenRouter virtuals. * Otherwise: filter to ``provider == "OPENROUTER"`` and ``model_name == provider_model``. Expect exactly one match — raise with a friendly message otherwise. @@ -173,7 +179,7 @@ def _resolve_openrouter_id( listing = "\n".join(f" id={c.id} name={c.name!r}" for c in matches) raise RuntimeError( f"Multiple OpenRouter configs for slug '{provider_model}':\n{listing}\n" - "Pass --agent-llm-id to disambiguate." + "Pass --chat-model-id to disambiguate." ) return matches[0].id @@ -186,7 +192,7 @@ def _resolve_openrouter_id( async def _cmd_setup(args: argparse.Namespace) -> int: suite = args.suite provider_model: str = args.provider_model - explicit_id: int | None = args.agent_llm_id + explicit_id: int | None = args.chat_model_id scenario: str = args.scenario vision_llm_slug: str | None = args.vision_llm native_arm_model: str | None = args.native_arm_model @@ -194,7 +200,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: if explicit_id == 0: console.print( - "[red]agent_llm_id == 0 (Auto / LiteLLM router) is not allowed — " + "[red]chat_model_id == 0 (Auto / LiteLLM router) is not allowed — " "results would not be reproducible.[/red]" ) return 2 @@ -242,7 +248,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: candidates = await _list_global_llm_configs(http, config.surfsense_api_base) try: - agent_llm_id = _resolve_openrouter_id( + chat_model_id = _resolve_openrouter_id( candidates, provider_model, explicit_id=explicit_id ) except RuntimeError as exc: @@ -288,7 +294,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: vision_provider_model: str | None = None if not skip_vision_setup and (vision_required or vision_llm_slug is not None): try: - vision_candidates = await ss_client.list_global_vision_llm_configs() + vision_candidates = await ss_client.list_global_vision_models() resolved = resolve_vision_llm( vision_candidates, explicit_slug=vision_llm_slug ) @@ -302,37 +308,34 @@ async def _cmd_setup(args: argparse.Namespace) -> int: f"(id={vision_config_id}, selected_via={resolved.selected_via})." ) - pref_kwargs: dict[str, Any] = {"agent_llm_id": agent_llm_id} + role_kwargs: dict[str, Any] = {"chat_model_id": chat_model_id} if vision_config_id is not None: - pref_kwargs["vision_llm_config_id"] = vision_config_id + role_kwargs["vision_model_id"] = vision_config_id - await ss_client.set_llm_preferences(search_space_id, **pref_kwargs) - prefs = await ss_client.get_llm_preferences(search_space_id) - if not _validate_pin(prefs, provider_model): - agent = prefs.agent_llm or {} + await ss_client.set_model_roles(search_space_id, **role_kwargs) + roles = await ss_client.get_model_roles(search_space_id) + if roles.chat_model_id != chat_model_id: console.print( f"[red]LLM pin validation FAILED.[/red] After PUT, " - f"agent_llm.provider={agent.get('provider')!r}, " - f"model_name={agent.get('model_name')!r}; expected " - f"provider=OPENROUTER, model_name={provider_model!r}." + f"chat_model_id={roles.chat_model_id!r}; expected {chat_model_id!r}." ) return 2 - if vision_config_id is not None and prefs.vision_llm_config_id != vision_config_id: + if vision_config_id is not None and roles.vision_model_id != vision_config_id: console.print( f"[red]Vision LLM pin validation FAILED.[/red] After PUT, " - f"vision_llm_config_id={prefs.vision_llm_config_id!r}; " + f"vision_model_id={roles.vision_model_id!r}; " f"expected {vision_config_id!r}." ) return 2 suite_state = SuiteState( search_space_id=search_space_id, - agent_llm_id=agent_llm_id, + chat_model_id=chat_model_id, provider_model=provider_model, created_at=utc_iso_timestamp(), ingestion_maps=existing.ingestion_maps if existing else {}, scenario=scenario, - vision_llm_config_id=vision_config_id, + vision_model_id=vision_config_id, vision_provider_model=vision_provider_model, native_arm_model=native_arm_model, ) @@ -342,7 +345,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: f"suite={suite!r}", f"scenario={scenario!r}", f"search_space_id={suite_state.search_space_id}", - f"agent_llm_id={suite_state.agent_llm_id}", + f"chat_model_id={suite_state.chat_model_id}", f"provider_model={suite_state.provider_model!r}", ] if suite_state.vision_provider_model: @@ -353,14 +356,6 @@ async def _cmd_setup(args: argparse.Namespace) -> int: return 0 -def _validate_pin(prefs: LlmPreferences, provider_model: str) -> bool: - agent = prefs.agent_llm or {} - return ( - str(agent.get("provider", "")).upper() == "OPENROUTER" - and str(agent.get("model_name", "")) == provider_model - ) - - async def _cmd_teardown(args: argparse.Namespace) -> int: suite = args.suite config = load_config() @@ -654,10 +649,10 @@ def _build_parser() -> argparse.ArgumentParser: ), ) p_setup.add_argument( - "--agent-llm-id", + "--chat-model-id", type=int, default=None, - help="Optional override for BYOK NewLLMConfig rows.", + help="Optional explicit model id override.", ) p_setup.add_argument( "--scenario", diff --git a/surfsense_evals/src/surfsense_evals/core/clients/search_space.py b/surfsense_evals/src/surfsense_evals/core/clients/search_space.py index e2d37694d..efd4a571d 100644 --- a/surfsense_evals/src/surfsense_evals/core/clients/search_space.py +++ b/surfsense_evals/src/surfsense_evals/core/clients/search_space.py @@ -1,17 +1,16 @@ -"""Client for ``/api/v1/searchspaces`` and ``/api/v1/search-spaces/{id}/llm-preferences``. +"""Client for ``/api/v1/searchspaces`` and model-role endpoints. Verified against: * ``surfsense_backend/app/routes/search_spaces_routes.py:116`` (POST create) * ``surfsense_backend/app/routes/search_spaces_routes.py:234`` (GET by id) * ``surfsense_backend/app/routes/search_spaces_routes.py:422`` (DELETE soft-delete) -* ``surfsense_backend/app/routes/search_spaces_routes.py:698-849`` (GET/PUT llm-preferences) +* ``surfsense_backend/app/routes/model_connections_routes.py`` (GET/PUT model roles) * ``surfsense_backend/app/schemas/search_space.py:14`` (SearchSpaceCreate body) -* ``surfsense_backend/app/routes/vision_llm_routes.py:60`` (GET global vision configs) Note the inconsistent pluralisation in the backend: ``/searchspaces`` -(no hyphen) for CRUD, but ``/search-spaces`` (hyphenated) for the -``llm-preferences`` sub-resource. Both are mirrored verbatim here. +(no hyphen) for CRUD, but ``/search-spaces`` (hyphenated) for model-role +sub-resources. Both are mirrored verbatim here. """ from __future__ import annotations @@ -46,13 +45,8 @@ class SearchSpaceRow: @dataclass -class VisionLlmConfigEntry: - """Subset of one ``GET /global-vision-llm-configs`` row. - - The backend returns negative ids for global / OpenRouter-derived - vision configs and positive ids for per-user BYOK rows. Either is - accepted by ``set_llm_preferences(vision_llm_config_id=...)``. - """ +class VisionModelEntry: + """Subset of one GLOBAL model-connection model with image input support.""" id: int name: str @@ -62,45 +56,38 @@ class VisionLlmConfigEntry: raw: dict[str, Any] @classmethod - def from_payload(cls, payload: dict[str, Any]) -> VisionLlmConfigEntry: + def from_payload(cls, payload: dict[str, Any]) -> VisionModelEntry: return cls( id=int(payload.get("id", 0)), - name=str(payload.get("name", "")), + name=str(payload.get("display_name") or payload.get("model_id") or ""), provider=str(payload.get("provider", "")).upper(), - model_name=str(payload.get("model_name", "")), - is_auto_mode=bool(payload.get("is_auto_mode", False)), + model_name=str(payload.get("model_id", "")), + is_auto_mode=False, raw=payload, ) @dataclass -class LlmPreferences: - """Resolved LLM preferences with the embedded full config row. +class ModelRoles: + """Model role ids for a search space.""" - Mirrors ``LLMPreferencesRead`` from the backend so the lifecycle - command can introspect ``provider`` / ``model_name`` to validate the - OpenRouter pin. - """ - - agent_llm_id: int | None - image_generation_config_id: int | None - vision_llm_config_id: int | None - agent_llm: dict[str, Any] | None + chat_model_id: int | None + image_gen_model_id: int | None + vision_model_id: int | None raw: dict[str, Any] @classmethod - def from_payload(cls, payload: dict[str, Any]) -> LlmPreferences: + def from_payload(cls, payload: dict[str, Any]) -> ModelRoles: return cls( - agent_llm_id=payload.get("agent_llm_id"), - image_generation_config_id=payload.get("image_generation_config_id"), - vision_llm_config_id=payload.get("vision_llm_config_id"), - agent_llm=payload.get("agent_llm"), + chat_model_id=payload.get("chat_model_id"), + image_gen_model_id=payload.get("image_gen_model_id"), + vision_model_id=payload.get("vision_model_id"), raw=payload, ) class SearchSpaceClient: - """Thin wrapper around the SearchSpace + LLM preferences endpoints.""" + """Thin wrapper around the SearchSpace + model role endpoints.""" def __init__(self, http: httpx.AsyncClient, base_url: str) -> None: self._http = http @@ -139,64 +126,67 @@ class SearchSpaceClient: return response.raise_for_status() - async def get_llm_preferences(self, search_space_id: int) -> LlmPreferences: + async def get_model_roles(self, search_space_id: int) -> ModelRoles: response = await self._http.get( - f"{self._base}/api/v1/search-spaces/{search_space_id}/llm-preferences", + f"{self._base}/api/v1/search-spaces/{search_space_id}/model-roles", headers={"Accept": "application/json"}, ) response.raise_for_status() - return LlmPreferences.from_payload(response.json()) + return ModelRoles.from_payload(response.json()) - async def set_llm_preferences( + async def set_model_roles( self, search_space_id: int, *, - agent_llm_id: int | None = None, - image_generation_config_id: int | None = None, - vision_llm_config_id: int | None = None, - ) -> LlmPreferences: - """PUT a partial update to ``/search-spaces/{id}/llm-preferences``. + chat_model_id: int | None = None, + image_gen_model_id: int | None = None, + vision_model_id: int | None = None, + ) -> ModelRoles: + """PUT a partial update to ``/search-spaces/{id}/model-roles``. Backend uses ``model_dump(exclude_unset=True)`` so omitted fields are left unchanged. """ body: dict[str, Any] = {} - if agent_llm_id is not None: - body["agent_llm_id"] = agent_llm_id - if image_generation_config_id is not None: - body["image_generation_config_id"] = image_generation_config_id - if vision_llm_config_id is not None: - body["vision_llm_config_id"] = vision_llm_config_id + if chat_model_id is not None: + body["chat_model_id"] = chat_model_id + if image_gen_model_id is not None: + body["image_gen_model_id"] = image_gen_model_id + if vision_model_id is not None: + body["vision_model_id"] = vision_model_id response = await self._http.put( - f"{self._base}/api/v1/search-spaces/{search_space_id}/llm-preferences", + f"{self._base}/api/v1/search-spaces/{search_space_id}/model-roles", json=body, headers={"Accept": "application/json"}, ) response.raise_for_status() - return LlmPreferences.from_payload(response.json()) + return ModelRoles.from_payload(response.json()) - async def list_global_vision_llm_configs(self) -> list[VisionLlmConfigEntry]: - """List the registered global vision LLM configs. + async def list_global_vision_models(self) -> list[VisionModelEntry]: + """List registered GLOBAL models that can accept image input. - Used by ``setup`` to (a) resolve an explicit ``--vision-llm `` - to a config id and (b) auto-pick the strongest registered vision - config when the operator doesn't pass one. The ``Auto (Fastest)`` - entry (``id=0``) is filtered out — accuracy must be reproducible. + Used by ``setup`` to resolve ``--vision-llm `` or auto-pick a + reproducible ingest-time vision model. """ response = await self._http.get( - f"{self._base}/api/v1/global-vision-llm-configs", + f"{self._base}/api/v1/model-connections/global", headers={"Accept": "application/json"}, ) response.raise_for_status() payload = response.json() if not isinstance(payload, list): raise RuntimeError( - f"Unexpected /global-vision-llm-configs payload: {payload!r}" + f"Unexpected /model-connections/global payload: {payload!r}" ) - return [ - VisionLlmConfigEntry.from_payload(item) - for item in payload - if not bool(item.get("is_auto_mode", False)) - ] + entries: list[VisionModelEntry] = [] + for connection in payload: + provider = str(connection.get("provider", "")) + for model in connection.get("models") or []: + if not model.get("enabled", True) or not model.get("supports_image_input"): + continue + entries.append( + VisionModelEntry.from_payload({**model, "provider": provider}) + ) + return entries diff --git a/surfsense_evals/src/surfsense_evals/core/config.py b/surfsense_evals/src/surfsense_evals/core/config.py index 164955914..9a5a71e89 100644 --- a/surfsense_evals/src/surfsense_evals/core/config.py +++ b/surfsense_evals/src/surfsense_evals/core/config.py @@ -147,35 +147,35 @@ class SuiteState: """Per-suite persisted state. ``provider_model`` is the slug pinned to the SearchSpace's - ``agent_llm`` — what answers SurfSense queries (and what the native + ``chat_model_id`` — what answers SurfSense queries (and what the native arm uses too, unless ``native_arm_model`` is set for cost-arbitrage). - ``vision_provider_model`` is the slug of the OpenRouter vision LLM - config attached to the SearchSpace's ``vision_llm_config_id`` — what + ``vision_provider_model`` is the slug of the OpenRouter vision model + attached to the SearchSpace's ``vision_model_id`` — what SurfSense uses to extract image content at ingest time when ``use_vision_llm=True``. ``None`` means no vision config was attached at setup (legacy or text-only suite). """ search_space_id: int - agent_llm_id: int + chat_model_id: int provider_model: str created_at: str ingestion_maps: dict[str, str] = field(default_factory=dict) scenario: str = DEFAULT_SCENARIO - vision_llm_config_id: int | None = None + vision_model_id: int | None = None vision_provider_model: str | None = None native_arm_model: str | None = None def to_dict(self) -> dict[str, Any]: return { "search_space_id": self.search_space_id, - "agent_llm_id": self.agent_llm_id, + "chat_model_id": self.chat_model_id, "provider_model": self.provider_model, "created_at": self.created_at, "ingestion_maps": dict(self.ingestion_maps), "scenario": self.scenario, - "vision_llm_config_id": self.vision_llm_config_id, + "vision_model_id": self.vision_model_id, "vision_provider_model": self.vision_provider_model, "native_arm_model": self.native_arm_model, } @@ -187,15 +187,16 @@ class SuiteState: scenario = str(payload.get("scenario") or DEFAULT_SCENARIO) if scenario not in SCENARIOS: scenario = DEFAULT_SCENARIO - raw_vision_id = payload.get("vision_llm_config_id") + raw_chat_id = payload.get("chat_model_id") + raw_vision_id = payload.get("vision_model_id") return cls( search_space_id=int(payload["search_space_id"]), - agent_llm_id=int(payload["agent_llm_id"]), + chat_model_id=int(raw_chat_id), provider_model=str(payload["provider_model"]), created_at=str(payload.get("created_at") or ""), ingestion_maps=dict(payload.get("ingestion_maps") or {}), scenario=scenario, - vision_llm_config_id=int(raw_vision_id) if raw_vision_id is not None else None, + vision_model_id=int(raw_vision_id) if raw_vision_id is not None else None, vision_provider_model=( str(payload["vision_provider_model"]) if payload.get("vision_provider_model") diff --git a/surfsense_evals/src/surfsense_evals/core/registry.py b/surfsense_evals/src/surfsense_evals/core/registry.py index cc8b725e0..65f64c39a 100644 --- a/surfsense_evals/src/surfsense_evals/core/registry.py +++ b/surfsense_evals/src/surfsense_evals/core/registry.py @@ -53,8 +53,8 @@ class RunContext: return self.suite_state.search_space_id @property - def agent_llm_id(self) -> int: - return self.suite_state.agent_llm_id + def chat_model_id(self) -> int: + return self.suite_state.chat_model_id @property def provider_model(self) -> str: diff --git a/surfsense_evals/src/surfsense_evals/core/vision_llm.py b/surfsense_evals/src/surfsense_evals/core/vision_llm.py index ae96f1285..5d5e2c6d1 100644 --- a/surfsense_evals/src/surfsense_evals/core/vision_llm.py +++ b/surfsense_evals/src/surfsense_evals/core/vision_llm.py @@ -3,8 +3,8 @@ Two responsibilities: 1. Resolve an explicit ``--vision-llm `` to a global OpenRouter - vision LLM config id that ``set_llm_preferences(vision_llm_config_id=...)`` - can accept. + vision-capable model id that ``set_model_roles(vision_model_id=...)`` can + accept. 2. Auto-pick the strongest registered vision config when the operator doesn't pass ``--vision-llm`` but the scenario / benchmark needs one. diff --git a/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py b/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py index e1a830138..ac0651996 100644 --- a/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py @@ -371,7 +371,7 @@ class MedXpertQAMMBenchmark: "provider_model": ctx.provider_model, "native_arm_model": native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "agent_llm_id": ctx.agent_llm_id, + "chat_model_id": ctx.chat_model_id, "ingest_settings": ingest_settings, }, ) diff --git a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py index 95a1e15eb..b7685766e 100644 --- a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py @@ -391,7 +391,7 @@ class MMLongBenchDocBenchmark: "provider_model": ctx.provider_model, "native_arm_model": native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "agent_llm_id": ctx.agent_llm_id, + "chat_model_id": ctx.chat_model_id, "ingest_settings": ingest_settings, }, ) diff --git a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py index e71dffa65..2c4a0ffe4 100644 --- a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py @@ -554,7 +554,7 @@ class ParserCompareBenchmark: "scenario": ctx.scenario, "provider_model": ctx.provider_model, "vision_provider_model": ctx.vision_provider_model, - "agent_llm_id": ctx.agent_llm_id, + "chat_model_id": ctx.chat_model_id, "preprocess_tariff": { "basic_per_1k_pages": 1.0, "premium_per_1k_pages": 10.0, diff --git a/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py b/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py index 8b759e0d8..654c261a2 100644 --- a/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py @@ -467,7 +467,7 @@ class CragBenchmark: "provider_model": ctx.provider_model, "native_arm_model": ctx.native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "agent_llm_id": ctx.agent_llm_id, + "chat_model_id": ctx.chat_model_id, "ingest_settings": ingest_settings, "per_page_char_cap": per_page_char_cap, "max_output_tokens": max_output_tokens, diff --git a/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py b/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py index 9c0e16b00..450c7ddd6 100644 --- a/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py @@ -372,7 +372,7 @@ class FramesBenchmark: "provider_model": ctx.provider_model, "native_arm_model": ctx.native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "agent_llm_id": ctx.agent_llm_id, + "chat_model_id": ctx.chat_model_id, "ingest_settings": ingest_settings, "bare_arm_label": "bare_llm", }, diff --git a/surfsense_evals/tests/core/test_clients.py b/surfsense_evals/tests/core/test_clients.py index 611408703..aa98f0ad4 100644 --- a/surfsense_evals/tests/core/test_clients.py +++ b/surfsense_evals/tests/core/test_clients.py @@ -63,29 +63,22 @@ async def test_delete_search_space_idempotent_on_404(respx_mock, http): @pytest.mark.asyncio @respx.mock(base_url=_BASE) -async def test_set_llm_preferences_partial_update(respx_mock, http): - route = respx_mock.put("/api/v1/search-spaces/42/llm-preferences").mock( +async def test_set_model_roles_partial_update(respx_mock, http): + route = respx_mock.put("/api/v1/search-spaces/42/model-roles").mock( return_value=httpx.Response( 200, json={ - "agent_llm_id": -10042, - "agent_llm_id": None, - "image_generation_config_id": None, - "vision_llm_config_id": None, - "agent_llm": { - "id": -10042, - "provider": "OPENROUTER", - "model_name": "anthropic/claude-sonnet-4.5", - }, + "chat_model_id": -10042, + "image_gen_model_id": None, + "vision_model_id": None, }, ) ) client = SearchSpaceClient(http, _BASE) - prefs = await client.set_llm_preferences(42, agent_llm_id=-10042) - assert prefs.agent_llm_id == -10042 - assert prefs.agent_llm["provider"] == "OPENROUTER" + roles = await client.set_model_roles(42, chat_model_id=-10042) + assert roles.chat_model_id == -10042 sent_body = json.loads(route.calls[-1].request.content) - assert sent_body == {"agent_llm_id": -10042} + assert sent_body == {"chat_model_id": -10042} # --------------------------------------------------------------------------- diff --git a/surfsense_evals/tests/core/test_config.py b/surfsense_evals/tests/core/test_config.py index f7b8f7249..6f9671c86 100644 --- a/surfsense_evals/tests/core/test_config.py +++ b/surfsense_evals/tests/core/test_config.py @@ -41,14 +41,14 @@ def test_state_roundtrip_per_suite(tmp_env): # noqa: ARG001 assert get_suite_state(config, "medical") is None state = SuiteState( search_space_id=1, - agent_llm_id=-10042, + chat_model_id=-10042, provider_model="anthropic/claude-sonnet-4.5", created_at="2026-05-11T20-30-00Z", ) set_suite_state(config, "medical", state) legal = SuiteState( search_space_id=2, - agent_llm_id=-1, + chat_model_id=-1, provider_model="openai/gpt-5", created_at="2026-05-11T21-00-00Z", ) @@ -84,25 +84,19 @@ def test_paths_are_per_suite(tmp_env): # noqa: ARG001 # --------------------------------------------------------------------------- -def test_legacy_state_back_compat_defaults_to_head_to_head(): - """state.json files written before scenarios shipped must still load. +def test_minimal_state_defaults_to_head_to_head(): + """Missing scenario / vision / native fields default safely.""" - Missing ``scenario`` / ``vision_*`` / ``native_arm_model`` keys all - default to ``head-to-head`` / ``None`` so old setups keep working - after upgrade — the runner's behaviour exactly mirrors the legacy - one (both arms answer with ``provider_model``). - """ - - legacy = { + payload = { "search_space_id": 7, - "agent_llm_id": -123, + "chat_model_id": -123, "provider_model": "anthropic/claude-sonnet-4.5", "created_at": "2026-05-11T20-30-00Z", "ingestion_maps": {}, } - state = SuiteState.from_dict(legacy) + state = SuiteState.from_dict(payload) assert state.scenario == DEFAULT_SCENARIO == "head-to-head" - assert state.vision_llm_config_id is None + assert state.vision_model_id is None assert state.vision_provider_model is None assert state.native_arm_model is None # The native arm should still answer with the same slug as SurfSense. @@ -118,7 +112,7 @@ def test_unknown_scenario_falls_back_to_default(): payload = { "search_space_id": 1, - "agent_llm_id": -1, + "chat_model_id": -1, "provider_model": "openai/gpt-5", "scenario": "unknown-scenario-name", } @@ -130,11 +124,11 @@ def test_cost_arbitrage_state_persists_native_arm_model(tmp_env): # noqa: ARG00 config = load_config() state = SuiteState( search_space_id=42, - agent_llm_id=-1, + chat_model_id=-1, provider_model="openai/gpt-5.4-mini", created_at="2026-05-11T20-30-00Z", scenario="cost-arbitrage", - vision_llm_config_id=-101, + vision_model_id=-101, vision_provider_model="anthropic/claude-sonnet-4.5", native_arm_model="anthropic/claude-sonnet-4.5", ) @@ -142,7 +136,7 @@ def test_cost_arbitrage_state_persists_native_arm_model(tmp_env): # noqa: ARG00 fetched = get_suite_state(config, "medical") assert fetched.scenario == "cost-arbitrage" - assert fetched.vision_llm_config_id == -101 + assert fetched.vision_model_id == -101 assert fetched.vision_provider_model == "anthropic/claude-sonnet-4.5" assert fetched.native_arm_model == "anthropic/claude-sonnet-4.5" # Cost arbitrage's whole point: native arm slug != surfsense slug. diff --git a/surfsense_evals/tests/test_integration_smoke.py b/surfsense_evals/tests/test_integration_smoke.py index 493c04c25..1c89ae5ab 100644 --- a/surfsense_evals/tests/test_integration_smoke.py +++ b/surfsense_evals/tests/test_integration_smoke.py @@ -27,7 +27,7 @@ async def test_smoke_against_localhost(): pytest.skip("No credentials in environment; skipping integration smoke") bundle = await acquire_token(config) async with client_with_auth(config, bundle) as client: - response = await client.get(f"{config.surfsense_api_base}/api/v1/global-new-llm-configs") + response = await client.get(f"{config.surfsense_api_base}/api/v1/model-connections/global") try: response.raise_for_status() except httpx.HTTPStatusError as exc: diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/image-models/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/image-models/page.tsx deleted file mode 100644 index b300f8078..000000000 --- a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/image-models/page.tsx +++ /dev/null @@ -1,6 +0,0 @@ -import { ImageModelManager } from "@/components/settings/image-model-manager"; - -export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) { - const { search_space_id } = await params; - return ; -} diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/roles/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/roles/page.tsx deleted file mode 100644 index 5bad50cd3..000000000 --- a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/roles/page.tsx +++ /dev/null @@ -1,6 +0,0 @@ -import { LLMRoleManager } from "@/components/settings/llm-role-manager"; - -export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) { - const { search_space_id } = await params; - return ; -} diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/vision-models/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/vision-models/page.tsx deleted file mode 100644 index 06aea003a..000000000 --- a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/vision-models/page.tsx +++ /dev/null @@ -1,6 +0,0 @@ -import { VisionModelManager } from "@/components/settings/vision-model-manager"; - -export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) { - const { search_space_id } = await params; - return ; -} diff --git a/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts b/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts deleted file mode 100644 index 922c398c9..000000000 --- a/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts +++ /dev/null @@ -1,96 +0,0 @@ -import { atomWithMutation } from "jotai-tanstack-query"; -import { toast } from "sonner"; -import type { - CreateImageGenConfigRequest, - CreateImageGenConfigResponse, - DeleteImageGenConfigResponse, - GetImageGenConfigsResponse, - UpdateImageGenConfigRequest, - UpdateImageGenConfigResponse, -} from "@/contracts/types/new-llm-config.types"; -import { imageGenConfigApiService } from "@/lib/apis/image-gen-config-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { queryClient } from "@/lib/query-client/client"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -/** - * Mutation atom for creating a new ImageGenerationConfig - */ -export const createImageGenConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["image-gen-configs", "create"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: CreateImageGenConfigRequest) => { - return imageGenConfigApiService.createConfig(request); - }, - onSuccess: (_: CreateImageGenConfigResponse, request: CreateImageGenConfigRequest) => { - toast.success(`${request.name} created`); - queryClient.invalidateQueries({ - queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to create image model"); - }, - }; -}); - -/** - * Mutation atom for updating an existing ImageGenerationConfig - */ -export const updateImageGenConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["image-gen-configs", "update"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: UpdateImageGenConfigRequest) => { - return imageGenConfigApiService.updateConfig(request); - }, - onSuccess: (_: UpdateImageGenConfigResponse, request: UpdateImageGenConfigRequest) => { - toast.success(`${request.data.name ?? "Configuration"} updated`); - queryClient.invalidateQueries({ - queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), - }); - queryClient.invalidateQueries({ - queryKey: cacheKeys.imageGenConfigs.byId(request.id), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to update image model"); - }, - }; -}); - -/** - * Mutation atom for deleting an ImageGenerationConfig - */ -export const deleteImageGenConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["image-gen-configs", "delete"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: { id: number; name: string }) => { - return imageGenConfigApiService.deleteConfig(request.id); - }, - onSuccess: (_: DeleteImageGenConfigResponse, request: { id: number; name: string }) => { - toast.success(`${request.name} deleted`); - queryClient.setQueryData( - cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), - (oldData: GetImageGenConfigsResponse | undefined) => { - if (!oldData) return oldData; - return oldData.filter((config) => config.id !== request.id); - } - ); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to delete image model"); - }, - }; -}); diff --git a/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts b/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts deleted file mode 100644 index a45e69a03..000000000 --- a/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts +++ /dev/null @@ -1,33 +0,0 @@ -import { atomWithQuery } from "jotai-tanstack-query"; -import { imageGenConfigApiService } from "@/lib/apis/image-gen-config-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -/** - * Query atom for fetching user-created image gen configs for the active search space - */ -export const imageGenConfigsAtom = atomWithQuery((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), - enabled: !!searchSpaceId, - staleTime: 5 * 60 * 1000, // 5 minutes - queryFn: async () => { - return imageGenConfigApiService.getConfigs(Number(searchSpaceId)); - }, - }; -}); - -/** - * Query atom for fetching global image gen configs (from YAML, negative IDs) - */ -export const globalImageGenConfigsAtom = atomWithQuery(() => { - return { - queryKey: cacheKeys.imageGenConfigs.global(), - staleTime: 10 * 60 * 1000, // 10 minutes - global configs rarely change - queryFn: async () => { - return imageGenConfigApiService.getGlobalConfigs(); - }, - }; -}); diff --git a/surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts b/surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts deleted file mode 100644 index 476d89d4c..000000000 --- a/surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts +++ /dev/null @@ -1,132 +0,0 @@ -import { atomWithMutation } from "jotai-tanstack-query"; -import { toast } from "sonner"; -import type { - CreateNewLLMConfigRequest, - CreateNewLLMConfigResponse, - DeleteNewLLMConfigRequest, - DeleteNewLLMConfigResponse, - GetNewLLMConfigsResponse, - UpdateLLMPreferencesRequest, - UpdateNewLLMConfigRequest, - UpdateNewLLMConfigResponse, -} from "@/contracts/types/new-llm-config.types"; -import { newLLMConfigApiService } from "@/lib/apis/new-llm-config-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { queryClient } from "@/lib/query-client/client"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -/** - * Mutation atom for creating a new NewLLMConfig - */ -export const createNewLLMConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["new-llm-configs", "create"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: CreateNewLLMConfigRequest) => { - return newLLMConfigApiService.createConfig(request); - }, - onSuccess: (_: CreateNewLLMConfigResponse, request: CreateNewLLMConfigRequest) => { - toast.success(`${request.name} created`); - queryClient.invalidateQueries({ - queryKey: cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to create model"); - }, - }; -}); - -/** - * Mutation atom for updating an existing NewLLMConfig - */ -export const updateNewLLMConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["new-llm-configs", "update"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: UpdateNewLLMConfigRequest) => { - return newLLMConfigApiService.updateConfig(request); - }, - onSuccess: (_: UpdateNewLLMConfigResponse, request: UpdateNewLLMConfigRequest) => { - toast.success(`${request.data.name ?? "Configuration"} updated`); - queryClient.invalidateQueries({ - queryKey: cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), - }); - queryClient.invalidateQueries({ - queryKey: cacheKeys.newLLMConfigs.byId(request.id), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to update"); - }, - }; -}); - -/** - * Mutation atom for deleting a NewLLMConfig - */ -export const deleteNewLLMConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["new-llm-configs", "delete"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: DeleteNewLLMConfigRequest & { name: string }) => { - return newLLMConfigApiService.deleteConfig({ id: request.id }); - }, - onSuccess: ( - _: DeleteNewLLMConfigResponse, - request: DeleteNewLLMConfigRequest & { name: string } - ) => { - toast.success(`${request.name} deleted`); - queryClient.setQueryData( - cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), - (oldData: GetNewLLMConfigsResponse | undefined) => { - if (!oldData) return oldData; - return oldData.filter((config) => config.id !== request.id); - } - ); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to delete"); - }, - }; -}); - -/** - * Mutation atom for updating LLM preferences (role assignments) - */ -export const updateLLMPreferencesMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["llm-preferences", "update"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: UpdateLLMPreferencesRequest) => { - return newLLMConfigApiService.updateLLMPreferences(request); - }, - onSuccess: (_data, request: UpdateLLMPreferencesRequest) => { - queryClient.setQueryData( - cacheKeys.newLLMConfigs.preferences(Number(searchSpaceId)), - (old: Record | undefined) => ({ ...old, ...request.data }) - ); - // Automation eligibility is derived from these model preferences - // (agent/image/vision). Invalidate it so the automations gate alert - // reflects the new selection without a manual refresh. - queryClient.invalidateQueries({ - queryKey: cacheKeys.automations.modelEligibility(Number(searchSpaceId)), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to update LLM preferences"); - }, - }; -}); diff --git a/surfsense_web/atoms/new-llm-config/new-llm-config-query.atoms.ts b/surfsense_web/atoms/new-llm-config/new-llm-config-query.atoms.ts deleted file mode 100644 index 410d061e5..000000000 --- a/surfsense_web/atoms/new-llm-config/new-llm-config-query.atoms.ts +++ /dev/null @@ -1,98 +0,0 @@ -import { atomWithQuery } from "jotai-tanstack-query"; -import type { LLMModel } from "@/contracts/enums/llm-models"; -import { LLM_MODELS } from "@/contracts/enums/llm-models"; -import { newLLMConfigApiService } from "@/lib/apis/new-llm-config-api.service"; -import { getBearerToken } from "@/lib/auth-utils"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -/** - * Query atom for fetching all NewLLMConfigs for the active search space - */ -export const newLLMConfigsAtom = atomWithQuery((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - queryKey: cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), - enabled: !!searchSpaceId, - staleTime: 5 * 60 * 1000, // 5 minutes - queryFn: async () => { - return newLLMConfigApiService.getConfigs({ - search_space_id: Number(searchSpaceId), - }); - }, - }; -}); - -/** - * Query atom for fetching global NewLLMConfigs (from YAML, negative IDs) - */ -export const globalNewLLMConfigsAtom = atomWithQuery(() => { - return { - queryKey: cacheKeys.newLLMConfigs.global(), - staleTime: 10 * 60 * 1000, // 10 minutes - global configs rarely change - enabled: !!getBearerToken(), - queryFn: async () => { - return newLLMConfigApiService.getGlobalConfigs(); - }, - }; -}); - -/** - * Query atom for fetching LLM preferences (role assignments) for the active search space - */ -export const llmPreferencesAtom = atomWithQuery((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - queryKey: cacheKeys.newLLMConfigs.preferences(Number(searchSpaceId)), - enabled: !!searchSpaceId, - staleTime: 5 * 60 * 1000, // 5 minutes - queryFn: async () => { - return newLLMConfigApiService.getLLMPreferences(Number(searchSpaceId)); - }, - }; -}); - -/** - * Query atom for fetching default system instructions template - */ -export const defaultSystemInstructionsAtom = atomWithQuery(() => { - return { - queryKey: cacheKeys.newLLMConfigs.defaultInstructions(), - staleTime: 60 * 60 * 1000, // 1 hour - this rarely changes - queryFn: async () => { - return newLLMConfigApiService.getDefaultSystemInstructions(); - }, - }; -}); - -/** - * Query atom for the dynamic model catalogue. - * Fetched from the backend (which proxies OpenRouter's public API). - * Falls back to the static hardcoded list on error. - */ -export const modelListAtom = atomWithQuery(() => { - return { - queryKey: cacheKeys.newLLMConfigs.modelList(), - staleTime: 60 * 60 * 1000, // 1 hour - models don't change often - placeholderData: LLM_MODELS, - queryFn: async (): Promise => { - const data = await newLLMConfigApiService.getModels(); - const dynamicModels = data.map((m) => ({ - value: m.value, - label: m.label, - provider: m.provider, - contextWindow: m.context_window ?? undefined, - })); - - // Providers covered by the dynamic API (from OpenRouter mapping). - // For uncovered providers (Ollama, Groq, Bedrock, etc.) keep the - // hand-curated static suggestions so users still see model options. - const coveredProviders = new Set(dynamicModels.map((m) => m.provider)); - const staticFallbacks = LLM_MODELS.filter((m) => !coveredProviders.has(m.provider)); - - return [...dynamicModels, ...staticFallbacks]; - }, - }; -}); diff --git a/surfsense_web/atoms/vision-llm-config/vision-llm-config-mutation.atoms.ts b/surfsense_web/atoms/vision-llm-config/vision-llm-config-mutation.atoms.ts deleted file mode 100644 index f46b977d5..000000000 --- a/surfsense_web/atoms/vision-llm-config/vision-llm-config-mutation.atoms.ts +++ /dev/null @@ -1,87 +0,0 @@ -import { atomWithMutation } from "jotai-tanstack-query"; -import { toast } from "sonner"; -import type { - CreateVisionLLMConfigRequest, - CreateVisionLLMConfigResponse, - DeleteVisionLLMConfigResponse, - GetVisionLLMConfigsResponse, - UpdateVisionLLMConfigRequest, - UpdateVisionLLMConfigResponse, -} from "@/contracts/types/new-llm-config.types"; -import { visionLLMConfigApiService } from "@/lib/apis/vision-llm-config-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { queryClient } from "@/lib/query-client/client"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -export const createVisionLLMConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["vision-llm-configs", "create"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: CreateVisionLLMConfigRequest) => { - return visionLLMConfigApiService.createConfig(request); - }, - onSuccess: (_: CreateVisionLLMConfigResponse, request: CreateVisionLLMConfigRequest) => { - toast.success(`${request.name} created`); - queryClient.invalidateQueries({ - queryKey: cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to create vision model"); - }, - }; -}); - -export const updateVisionLLMConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["vision-llm-configs", "update"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: UpdateVisionLLMConfigRequest) => { - return visionLLMConfigApiService.updateConfig(request); - }, - onSuccess: (_: UpdateVisionLLMConfigResponse, request: UpdateVisionLLMConfigRequest) => { - toast.success(`${request.data.name ?? "Configuration"} updated`); - queryClient.invalidateQueries({ - queryKey: cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), - }); - queryClient.invalidateQueries({ - queryKey: cacheKeys.visionLLMConfigs.byId(request.id), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to update vision model"); - }, - }; -}); - -export const deleteVisionLLMConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["vision-llm-configs", "delete"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: { id: number; name: string }) => { - return visionLLMConfigApiService.deleteConfig(request.id); - }, - onSuccess: (_: DeleteVisionLLMConfigResponse, request: { id: number; name: string }) => { - toast.success(`${request.name} deleted`); - queryClient.setQueryData( - cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), - (oldData: GetVisionLLMConfigsResponse | undefined) => { - if (!oldData) return oldData; - return oldData.filter((config) => config.id !== request.id); - } - ); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to delete vision model"); - }, - }; -}); diff --git a/surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts b/surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts deleted file mode 100644 index 906ce638f..000000000 --- a/surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts +++ /dev/null @@ -1,51 +0,0 @@ -import { atomWithQuery } from "jotai-tanstack-query"; -import type { LLMModel } from "@/contracts/enums/llm-models"; -import { VISION_MODELS } from "@/contracts/enums/vision-providers"; -import { visionLLMConfigApiService } from "@/lib/apis/vision-llm-config-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -export const visionLLMConfigsAtom = atomWithQuery((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - queryKey: cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), - enabled: !!searchSpaceId, - staleTime: 5 * 60 * 1000, - queryFn: async () => { - return visionLLMConfigApiService.getConfigs(Number(searchSpaceId)); - }, - }; -}); - -export const globalVisionLLMConfigsAtom = atomWithQuery(() => { - return { - queryKey: cacheKeys.visionLLMConfigs.global(), - staleTime: 10 * 60 * 1000, - queryFn: async () => { - return visionLLMConfigApiService.getGlobalConfigs(); - }, - }; -}); - -export const visionModelListAtom = atomWithQuery(() => { - return { - queryKey: cacheKeys.visionLLMConfigs.modelList(), - staleTime: 60 * 60 * 1000, - placeholderData: VISION_MODELS, - queryFn: async (): Promise => { - const data = await visionLLMConfigApiService.getModels(); - const dynamicModels = data.map((m) => ({ - value: m.value, - label: m.label, - provider: m.provider, - contextWindow: m.context_window ?? undefined, - })); - - const coveredProviders = new Set(dynamicModels.map((m) => m.provider)); - const staticFallbacks = VISION_MODELS.filter((m) => !coveredProviders.has(m.provider)); - - return [...dynamicModels, ...staticFallbacks]; - }, - }; -}); diff --git a/surfsense_web/components/new-chat/chat-header.tsx b/surfsense_web/components/new-chat/chat-header.tsx index 4716418ee..d65dc93a7 100644 --- a/surfsense_web/components/new-chat/chat-header.tsx +++ b/surfsense_web/components/new-chat/chat-header.tsx @@ -1,17 +1,5 @@ "use client"; -import { useCallback, useState } from "react"; -import { ImageConfigDialog } from "@/components/shared/image-config-dialog"; -import { ModelConfigDialog } from "@/components/shared/model-config-dialog"; -import { VisionConfigDialog } from "@/components/shared/vision-config-dialog"; -import type { - GlobalImageGenConfig, - GlobalNewLLMConfig, - GlobalVisionLLMConfig, - ImageGenerationConfig, - NewLLMConfigPublic, - VisionLLMConfig, -} from "@/contracts/types/new-llm-config.types"; import { ModelSelector } from "./model-selector"; interface ChatHeaderProps { @@ -20,148 +8,9 @@ interface ChatHeaderProps { } export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) { - // LLM config dialog state - const [dialogOpen, setDialogOpen] = useState(false); - const [selectedConfig, setSelectedConfig] = useState< - NewLLMConfigPublic | GlobalNewLLMConfig | null - >(null); - const [isGlobal, setIsGlobal] = useState(false); - const [dialogMode, setDialogMode] = useState<"create" | "edit" | "view">("view"); - - // Image config dialog state - const [imageDialogOpen, setImageDialogOpen] = useState(false); - const [selectedImageConfig, setSelectedImageConfig] = useState< - ImageGenerationConfig | GlobalImageGenConfig | null - >(null); - const [isImageGlobal, setIsImageGlobal] = useState(false); - const [imageDialogMode, setImageDialogMode] = useState<"create" | "edit" | "view">("view"); - - // Vision config dialog state - const [visionDialogOpen, setVisionDialogOpen] = useState(false); - const [selectedVisionConfig, setSelectedVisionConfig] = useState< - VisionLLMConfig | GlobalVisionLLMConfig | null - >(null); - const [isVisionGlobal, setIsVisionGlobal] = useState(false); - const [visionDialogMode, setVisionDialogMode] = useState<"create" | "edit" | "view">("view"); - - // Default provider for create dialogs - const [defaultLLMProvider, setDefaultLLMProvider] = useState(); - const [defaultImageProvider, setDefaultImageProvider] = useState(); - const [defaultVisionProvider, setDefaultVisionProvider] = useState(); - - // LLM handlers - const handleEditLLMConfig = useCallback( - (config: NewLLMConfigPublic | GlobalNewLLMConfig, global: boolean) => { - setSelectedConfig(config); - setIsGlobal(global); - setDialogMode(global ? "view" : "edit"); - setDefaultLLMProvider(undefined); - setDialogOpen(true); - }, - [] - ); - - const handleAddNewLLM = useCallback((provider?: string) => { - setSelectedConfig(null); - setIsGlobal(false); - setDialogMode("create"); - setDefaultLLMProvider(provider); - setDialogOpen(true); - }, []); - - const handleDialogClose = useCallback((open: boolean) => { - setDialogOpen(open); - if (!open) setSelectedConfig(null); - }, []); - - // Image model handlers - const handleAddImageModel = useCallback((provider?: string) => { - setSelectedImageConfig(null); - setIsImageGlobal(false); - setImageDialogMode("create"); - setDefaultImageProvider(provider); - setImageDialogOpen(true); - }, []); - - const handleEditImageConfig = useCallback( - (config: ImageGenerationConfig | GlobalImageGenConfig, global: boolean) => { - setSelectedImageConfig(config); - setIsImageGlobal(global); - setImageDialogMode(global ? "view" : "edit"); - setDefaultImageProvider(undefined); - setImageDialogOpen(true); - }, - [] - ); - - const handleImageDialogClose = useCallback((open: boolean) => { - setImageDialogOpen(open); - if (!open) setSelectedImageConfig(null); - }, []); - - // Vision model handlers - const handleAddVisionModel = useCallback((provider?: string) => { - setSelectedVisionConfig(null); - setIsVisionGlobal(false); - setVisionDialogMode("create"); - setDefaultVisionProvider(provider); - setVisionDialogOpen(true); - }, []); - - const handleEditVisionConfig = useCallback( - (config: VisionLLMConfig | GlobalVisionLLMConfig, global: boolean) => { - setSelectedVisionConfig(config); - setIsVisionGlobal(global); - setVisionDialogMode(global ? "view" : "edit"); - setDefaultVisionProvider(undefined); - setVisionDialogOpen(true); - }, - [] - ); - - const handleVisionDialogClose = useCallback((open: boolean) => { - setVisionDialogOpen(open); - if (!open) setSelectedVisionConfig(null); - }, []); - return (
- - - - +
); } diff --git a/surfsense_web/components/settings/agent-model-manager.tsx b/surfsense_web/components/settings/agent-model-manager.tsx deleted file mode 100644 index 507a263e0..000000000 --- a/surfsense_web/components/settings/agent-model-manager.tsx +++ /dev/null @@ -1,423 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { AlertCircle, Dot, FileText, Info, Pencil, RefreshCw, Trash2 } from "lucide-react"; -import { useMemo, useState } from "react"; -import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; -import { deleteNewLLMConfigMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; -import { - globalNewLLMConfigsAtom, - newLLMConfigsAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; -import { ModelConfigDialog } from "@/components/shared/model-config-dialog"; -import { Alert, AlertDescription } from "@/components/ui/alert"; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, -} from "@/components/ui/alert-dialog"; -import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Card, CardContent } from "@/components/ui/card"; -import { Separator } from "@/components/ui/separator"; -import { Skeleton } from "@/components/ui/skeleton"; -import { Spinner } from "@/components/ui/spinner"; -import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; -import type { NewLLMConfig } from "@/contracts/types/new-llm-config.types"; -import { useMediaQuery } from "@/hooks/use-media-query"; -import { getProviderIcon } from "@/lib/provider-icons"; -import { cn } from "@/lib/utils"; - -interface AgentModelManagerProps { - searchSpaceId: number; -} - -function getInitials(name: string): string { - const parts = name.trim().split(/\s+/); - if (parts.length >= 2) { - return (parts[0][0] + parts[1][0]).toUpperCase(); - } - return name.slice(0, 2).toUpperCase(); -} - -export function AgentModelManager({ searchSpaceId }: AgentModelManagerProps) { - const isDesktop = useMediaQuery("(min-width: 768px)"); - // Mutations - const { mutateAsync: deleteConfig, isPending: isDeleting } = useAtomValue( - deleteNewLLMConfigMutationAtom - ); - - // Queries - const { - data: configs, - isFetching: isLoading, - error: fetchError, - refetch: refreshConfigs, - } = useAtomValue(newLLMConfigsAtom); - const { data: globalConfigs = [] } = useAtomValue(globalNewLLMConfigsAtom); - - // Members for user resolution - const { data: members } = useAtomValue(membersAtom); - const memberMap = useMemo(() => { - const map = new Map(); - if (members) { - for (const m of members) { - map.set(m.user_id, { - name: m.user_display_name || m.user_email || "Unknown", - email: m.user_email || undefined, - avatarUrl: m.user_avatar_url || undefined, - }); - } - } - return map; - }, [members]); - - // Permissions - const { data: access } = useAtomValue(myAccessAtom); - const canCreate = - !!access && (access.is_owner || (access.permissions?.includes("llm_configs:create") ?? false)); - const canUpdate = - !!access && (access.is_owner || (access.permissions?.includes("llm_configs:update") ?? false)); - const canDelete = - !!access && (access.is_owner || (access.permissions?.includes("llm_configs:delete") ?? false)); - const isReadOnly = !canCreate && !canUpdate && !canDelete; - - // Local state - const [isDialogOpen, setIsDialogOpen] = useState(false); - const [editingConfig, setEditingConfig] = useState(null); - const [configToDelete, setConfigToDelete] = useState(null); - - const handleDelete = async () => { - if (!configToDelete) return; - try { - await deleteConfig({ id: configToDelete.id, name: configToDelete.name }); - setConfigToDelete(null); - } catch { - // Error handled by mutation state - } - }; - - const openEditDialog = (config: NewLLMConfig) => { - setEditingConfig(config); - setIsDialogOpen(true); - }; - - const openNewDialog = () => { - setEditingConfig(null); - setIsDialogOpen(true); - }; - - return ( -
- {/* Header actions */} -
- - {canCreate && ( - - )} -
- - {/* Fetch Error Alert */} - {fetchError && ( -
- - - - {fetchError?.message ?? "Failed to load configurations"} - - -
- )} - - {/* Read-only / Limited permissions notice */} - {access && !isLoading && isReadOnly && ( -
- - - -

- You have read-only access to LLM - configurations. Contact a space owner to request additional permissions. -

-
-
-
- )} - {access && !isLoading && !isReadOnly && (!canCreate || !canUpdate || !canDelete) && ( -
- - - -

- You can{" "} - {[canCreate && "create", canUpdate && "edit", canDelete && "delete"] - .filter(Boolean) - .join(" and ")}{" "} - configurations - {!canDelete && ", but cannot delete them"}. -

-
-
-
- )} - - {/* Global Configs Info */} - {(isLoading || globalConfigs.length > 0) && ( - - - - {isLoading ? ( -
- -
- ) : ( -

- - {globalConfigs.length} global {globalConfigs.length === 1 ? "model" : "models"} - {" "} - available from your administrator. -

- )} -
-
- )} - - {/* Loading Skeleton */} - {isLoading && ( -
- {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( - - - - - - - - ))} -
- )} - - {/* Configurations List */} - {!isLoading && ( -
- {configs?.length === 0 ? ( -
- - -

No Models Yet

-

- {canCreate - ? "Add your first model to power chat, reports, and other agent capabilities" - : "No models have been added to this space yet. Contact a space owner to add one"} -

-
-
-
- ) : ( -
- {configs?.map((config) => { - const member = config.user_id ? memberMap.get(config.user_id) : null; - - return ( -
- - - {/* Header: Icon + Name + Actions */} -
-
-
- {getProviderIcon(config.provider, { className: "size-4" })} -
-
-

- {config.name} -

- {config.description && ( -

- {config.description} -

- )} -
-
- {(canUpdate || canDelete) && ( -
- {canUpdate && ( - - - - - - Edit - - - )} - {canDelete && ( - - - - - - Delete - - - )} -
- )} -
- - {/* Feature badges */} -
- {config.citations_enabled && ( - - Citations - - )} - {!config.use_default_system_instructions && - config.system_instructions && ( - - - Custom - - )} -
- - {/* Footer: Date + Creator */} -
- -
- - {new Date(config.created_at).toLocaleDateString(undefined, { - year: "numeric", - month: "short", - day: "numeric", - })} - - {member && ( - <> - - - - -
- - {member.avatarUrl && ( - - )} - - {getInitials(member.name)} - - - - {member.name} - -
-
- - {member.email || member.name} - -
-
- - )} -
-
-
-
-
- ); - })} -
- )} -
- )} - - {/* Add/Edit Configuration Dialog */} - { - setIsDialogOpen(open); - if (!open) setEditingConfig(null); - }} - config={editingConfig} - isGlobal={false} - searchSpaceId={searchSpaceId} - mode={editingConfig ? "edit" : "create"} - /> - - {/* Delete Confirmation Dialog */} - !open && setConfigToDelete(null)} - > - - - Delete Model - - Are you sure you want to delete{" "} - {configToDelete?.name}? This - action cannot be undone. - - - - Cancel - - {isDeleting ? ( - <> - - Deleting - - ) : ( - "Delete" - )} - - - - -
- ); -} diff --git a/surfsense_web/components/settings/image-model-manager.tsx b/surfsense_web/components/settings/image-model-manager.tsx deleted file mode 100644 index 494f7aae9..000000000 --- a/surfsense_web/components/settings/image-model-manager.tsx +++ /dev/null @@ -1,489 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { AlertCircle, Dot, Info, Pencil, RefreshCw, Trash2 } from "lucide-react"; -import { useMemo, useState } from "react"; -import { deleteImageGenConfigMutationAtom } from "@/atoms/image-gen-config/image-gen-config-mutation.atoms"; -import { - globalImageGenConfigsAtom, - imageGenConfigsAtom, -} from "@/atoms/image-gen-config/image-gen-config-query.atoms"; -import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; -import { ImageConfigDialog } from "@/components/shared/image-config-dialog"; -import { Alert, AlertDescription } from "@/components/ui/alert"; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, -} from "@/components/ui/alert-dialog"; -import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Card, CardContent } from "@/components/ui/card"; -import { Separator } from "@/components/ui/separator"; -import { Skeleton } from "@/components/ui/skeleton"; -import { Spinner } from "@/components/ui/spinner"; -import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; -import type { ImageGenerationConfig } from "@/contracts/types/new-llm-config.types"; -import { useMediaQuery } from "@/hooks/use-media-query"; -import { getProviderIcon } from "@/lib/provider-icons"; -import { cn } from "@/lib/utils"; - -interface ImageModelManagerProps { - searchSpaceId: number; -} - -function getInitials(name: string): string { - const parts = name.trim().split(/\s+/); - if (parts.length >= 2) { - return (parts[0][0] + parts[1][0]).toUpperCase(); - } - return name.slice(0, 2).toUpperCase(); -} - -export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { - const isDesktop = useMediaQuery("(min-width: 768px)"); - - const { - mutateAsync: deleteConfig, - isPending: isDeleting, - error: deleteError, - } = useAtomValue(deleteImageGenConfigMutationAtom); - - const { - data: userConfigs, - isFetching: configsLoading, - error: fetchError, - refetch: refreshConfigs, - } = useAtomValue(imageGenConfigsAtom); - const { data: globalConfigs = [], isFetching: globalLoading } = - useAtomValue(globalImageGenConfigsAtom); - - const { data: members } = useAtomValue(membersAtom); - const memberMap = useMemo(() => { - const map = new Map(); - if (members) { - for (const m of members) { - map.set(m.user_id, { - name: m.user_display_name || m.user_email || "Unknown", - email: m.user_email || undefined, - avatarUrl: m.user_avatar_url || undefined, - }); - } - } - return map; - }, [members]); - - const { data: access } = useAtomValue(myAccessAtom); - const canCreate = - !!access && - (access.is_owner || (access.permissions?.includes("image_generations:create") ?? false)); - const canDelete = - !!access && - (access.is_owner || (access.permissions?.includes("image_generations:delete") ?? false)); - const canUpdate = canCreate; - const isReadOnly = !canCreate && !canDelete; - - const [isDialogOpen, setIsDialogOpen] = useState(false); - const [editingConfig, setEditingConfig] = useState(null); - const [configToDelete, setConfigToDelete] = useState(null); - - const isLoading = configsLoading || globalLoading; - const errors = [deleteError, fetchError].filter(Boolean) as Error[]; - - const openEditDialog = (config: ImageGenerationConfig) => { - setEditingConfig(config); - setIsDialogOpen(true); - }; - - const openNewDialog = () => { - setEditingConfig(null); - setIsDialogOpen(true); - }; - - const handleDelete = async () => { - if (!configToDelete) return; - try { - await deleteConfig({ id: configToDelete.id, name: configToDelete.name }); - setConfigToDelete(null); - } catch { - // Error handled by mutation - } - }; - - return ( -
- {/* Header actions */} -
- - {canCreate && ( - - )} -
- - {/* Errors */} - {errors.map((err) => ( -
- - - {err?.message} - -
- ))} - - {/* Read-only / Limited permissions notice */} - {access && !isLoading && isReadOnly && ( -
- - - -

- You have read-only access to image generation - configurations. Contact a space owner to request additional permissions. -

-
-
-
- )} - {access && !isLoading && !isReadOnly && (!canCreate || !canDelete) && ( -
- - - -

- You can{" "} - {[canCreate && "create and edit", canDelete && "delete"] - .filter(Boolean) - .join(" and ")}{" "} - image model configurations - {!canDelete && ", but cannot delete them"}. -

-
-
-
- )} - - {/* Global info */} - {(isLoading || - globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0) && ( - - - - {isLoading ? ( -
- -
- ) : ( -

- - {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length}{" "} - global image{" "} - {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length === - 1 - ? "model" - : "models"} - {" "} - available from your administrator. {(() => { - const nonAuto = globalConfigs.filter( - (g) => !("is_auto_mode" in g && g.is_auto_mode) - ); - const premium = nonAuto.filter( - (g) => - "billing_tier" in g && - (g as { billing_tier?: string }).billing_tier === "premium" - ).length; - const free = nonAuto.length - premium; - if (premium > 0 && free > 0) { - return `${premium} premium, ${free} free.`; - } - if (premium > 0) { - return `All ${premium} premium — debits your shared credit pool.`; - } - return `All ${free} free.`; - })()} -

- )} -
-
- )} - - {/* Global Image Models — read-only cards with per-model Free/Premium - badges. Mirrors the badge palette used by the chat role selector - (`llm-role-manager.tsx`) so the meaning is consistent across - every model-configuration surface (chat / image / vision). */} - {!isLoading && - globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && ( -
-
- {globalConfigs - .filter((g) => !("is_auto_mode" in g && g.is_auto_mode)) - .map((cfg) => { - const billingTier = - ("billing_tier" in cfg && - typeof (cfg as { billing_tier?: string }).billing_tier === "string" && - (cfg as { billing_tier?: string }).billing_tier) || - "free"; - const isPremium = billingTier === "premium"; - return ( - - -
-
- {getProviderIcon(cfg.provider, { className: "size-4" })} -
-
-

- {cfg.name} -

- {isPremium ? ( - - Premium - - ) : ( - - Free - - )} -
-
- {cfg.description && ( -

- {cfg.description} -

- )} -
- -
- - {cfg.model_name} - -
-
-
-
- ); - })} -
-
- )} - - {/* Loading Skeleton */} - {isLoading && ( -
-
-
- {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( - - - - - - - - ))} -
-
-
- )} - - {/* User Configs */} - {!isLoading && ( -
- {(userConfigs?.length ?? 0) === 0 ? ( - - -

No Image Models Yet

-

- {canCreate - ? "Add your own image generation model (DALL-E 3, GPT Image 1, etc.)" - : "No image models have been added to this space yet. Contact a space owner to add one."} -

-
-
- ) : ( -
- {userConfigs?.map((config) => { - const member = config.user_id ? memberMap.get(config.user_id) : null; - - return ( -
- - - {/* Header: Icon + Name + Actions */} -
-
-
- {getProviderIcon(config.provider, { className: "size-4" })} -
-
-

- {config.name} -

- {config.description && ( -

- {config.description} -

- )} -
-
- {(canUpdate || canDelete) && ( -
- {canUpdate && ( - - - - - - Edit - - - )} - {canDelete && ( - - - - - - Delete - - - )} -
- )} -
- - {/* Footer: Date + Creator */} -
- -
- - {new Date(config.created_at).toLocaleDateString(undefined, { - year: "numeric", - month: "short", - day: "numeric", - })} - - {member && ( - <> - - - - -
- - {member.avatarUrl && ( - - )} - - {getInitials(member.name)} - - - - {member.name} - -
-
- - {member.email || member.name} - -
-
- - )} -
-
-
-
-
- ); - })} -
- )} -
- )} - - {/* Create/Edit Dialog — shared component */} - { - setIsDialogOpen(open); - if (!open) setEditingConfig(null); - }} - config={editingConfig} - isGlobal={false} - searchSpaceId={searchSpaceId} - mode={editingConfig ? "edit" : "create"} - /> - - {/* Delete Confirmation */} - !open && setConfigToDelete(null)} - > - - - Delete Image Model - - Are you sure you want to delete{" "} - {configToDelete?.name}? - - - - Cancel - - Delete - {isDeleting && } - - - - -
- ); -} diff --git a/surfsense_web/components/settings/llm-role-manager.tsx b/surfsense_web/components/settings/llm-role-manager.tsx deleted file mode 100644 index 547675927..000000000 --- a/surfsense_web/components/settings/llm-role-manager.tsx +++ /dev/null @@ -1,443 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { - AlertCircle, - Bot, - CircleCheck, - CircleDashed, - FileText, - ImageIcon, - RefreshCw, - ScanEye, -} from "lucide-react"; -import { useCallback, useEffect, useState } from "react"; -import { toast } from "sonner"; -import { - globalImageGenConfigsAtom, - imageGenConfigsAtom, -} from "@/atoms/image-gen-config/image-gen-config-query.atoms"; -import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; -import { - globalNewLLMConfigsAtom, - llmPreferencesAtom, - newLLMConfigsAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; -import { - globalVisionLLMConfigsAtom, - visionLLMConfigsAtom, -} from "@/atoms/vision-llm-config/vision-llm-config-query.atoms"; -import { Alert, AlertDescription } from "@/components/ui/alert"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Card, CardContent } from "@/components/ui/card"; -import { Label } from "@/components/ui/label"; -import { - Select, - SelectContent, - SelectGroup, - SelectItem, - SelectLabel, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { Skeleton } from "@/components/ui/skeleton"; -import { Spinner } from "@/components/ui/spinner"; -import { cn } from "@/lib/utils"; - -const ROLE_DESCRIPTIONS = { - agent: { - icon: Bot, - title: "Chat model", - description: "Primary model for chat interactions and agent operations", - color: "text-muted-foreground", - bgColor: "bg-muted", - prefKey: "agent_llm_id" as const, - configType: "llm" as const, - }, - image_generation: { - icon: ImageIcon, - title: "Image Generation Model", - description: "Model used for AI image generation (DALL-E, GPT Image, etc.)", - color: "text-muted-foreground", - bgColor: "bg-muted", - prefKey: "image_generation_config_id" as const, - configType: "image" as const, - }, - vision: { - icon: ScanEye, - title: "Vision LLM", - description: "Vision-capable model for screenshot analysis and context extraction", - color: "text-muted-foreground", - bgColor: "bg-muted", - prefKey: "vision_llm_config_id" as const, - configType: "vision" as const, - }, -}; - -interface LLMRoleManagerProps { - searchSpaceId: number; -} - -export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) { - // LLM configs - const { - data: newLLMConfigs = [], - isFetching: configsLoading, - error: configsError, - refetch: refreshConfigs, - } = useAtomValue(newLLMConfigsAtom); - const { - data: globalConfigs = [], - isFetching: globalConfigsLoading, - error: globalConfigsError, - } = useAtomValue(globalNewLLMConfigsAtom); - - // Image gen configs - const { - data: userImageConfigs = [], - isFetching: imageConfigsLoading, - error: imageConfigsError, - } = useAtomValue(imageGenConfigsAtom); - const { - data: globalImageConfigs = [], - isFetching: globalImageConfigsLoading, - error: globalImageConfigsError, - } = useAtomValue(globalImageGenConfigsAtom); - - // Vision LLM configs - const { - data: userVisionConfigs = [], - isFetching: visionConfigsLoading, - error: visionConfigsError, - } = useAtomValue(visionLLMConfigsAtom); - const { - data: globalVisionConfigs = [], - isFetching: globalVisionConfigsLoading, - error: globalVisionConfigsError, - } = useAtomValue(globalVisionLLMConfigsAtom); - - // Preferences - const { - data: preferences = {}, - isFetching: preferencesLoading, - error: preferencesError, - } = useAtomValue(llmPreferencesAtom); - - const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); - - const [assignments, setAssignments] = useState>(() => ({ - agent_llm_id: preferences.agent_llm_id ?? null, - image_generation_config_id: preferences.image_generation_config_id ?? null, - vision_llm_config_id: preferences.vision_llm_config_id ?? null, - })); - - // Sync local state when preferences load/change. Without this, the selects - // stay on their initial (often empty) value while the query is in flight, - // so a saved assignment — including Auto mode (id 0) — never appears. - useEffect(() => { - setAssignments({ - agent_llm_id: preferences.agent_llm_id ?? null, - image_generation_config_id: preferences.image_generation_config_id ?? null, - vision_llm_config_id: preferences.vision_llm_config_id ?? null, - }); - }, [ - preferences.agent_llm_id, - preferences.image_generation_config_id, - preferences.vision_llm_config_id, - ]); - - const [savingRole, setSavingRole] = useState(null); - - const handleRoleAssignment = useCallback( - async (prefKey: string, configId: string) => { - // "unassigned" clears the role (null). Every other option — including - // Auto mode, whose config id is 0 — must be sent as-is. Using a falsy - // check here (e.g. `value || undefined`) would drop id 0 and silently - // fail to persist Auto mode. - const value = configId === "unassigned" ? null : Number(configId); - - setAssignments((prev) => ({ ...prev, [prefKey]: value })); - setSavingRole(prefKey); - - try { - await updatePreferences({ - search_space_id: searchSpaceId, - data: { [prefKey]: value }, - }); - toast.success("Role assignment updated"); - } finally { - setSavingRole(null); - } - }, - [updatePreferences, searchSpaceId] - ); - - // Combine global and custom LLM configs - const allLLMConfigs = [ - ...globalConfigs.map((config) => ({ ...config, is_global: true })), - ...newLLMConfigs.filter((config) => config.id && config.id.toString().trim() !== ""), - ]; - - // Combine global and custom image gen configs - const allImageConfigs = [ - ...globalImageConfigs.map((config) => ({ ...config, is_global: true })), - ...(userImageConfigs ?? []).filter((config) => config.id && config.id.toString().trim() !== ""), - ]; - - // Combine global and custom vision LLM configs - const allVisionConfigs = [ - ...globalVisionConfigs.map((config) => ({ ...config, is_global: true })), - ...(userVisionConfigs ?? []).filter( - (config) => config.id && config.id.toString().trim() !== "" - ), - ]; - - const isLoading = - configsLoading || - preferencesLoading || - globalConfigsLoading || - imageConfigsLoading || - globalImageConfigsLoading || - visionConfigsLoading || - globalVisionConfigsLoading; - const hasError = - configsError || - preferencesError || - globalConfigsError || - imageConfigsError || - globalImageConfigsError || - visionConfigsError || - globalVisionConfigsError; - const hasAnyConfigs = allLLMConfigs.length > 0 || allImageConfigs.length > 0; - - return ( -
- {/* Header actions */} -
- -
- - {/* Error Alert */} - {hasError && ( -
- - - - {(configsError?.message ?? "Failed to load LLM configurations") || - (preferencesError?.message ?? "Failed to load preferences") || - (globalConfigsError?.message ?? "Failed to load global configurations")} - - -
- )} - - {/* Loading Skeleton */} - {isLoading && ( -
- {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( - - - - - - - - ))} -
- )} - - {/* No configs warning */} - {!isLoading && !hasError && !hasAnyConfigs && ( - - - - No configurations found. Please add at least one LLM provider or image model in the - respective settings tabs before assigning roles. - - - )} - - {/* Role Assignment Cards */} - {!isLoading && !hasError && hasAnyConfigs && ( -
- {Object.entries(ROLE_DESCRIPTIONS).map(([key, role]) => { - const IconComponent = role.icon; - const currentAssignment = assignments[role.prefKey as keyof typeof assignments]; - - // Pick the right config lists based on role type - const roleGlobalConfigs = - role.configType === "image" - ? globalImageConfigs - : role.configType === "vision" - ? globalVisionConfigs - : globalConfigs; - const roleUserConfigs = - role.configType === "image" - ? (userImageConfigs ?? []).filter((c) => c.id && c.id.toString().trim() !== "") - : role.configType === "vision" - ? (userVisionConfigs ?? []).filter((c) => c.id && c.id.toString().trim() !== "") - : newLLMConfigs.filter((c) => c.id && c.id.toString().trim() !== ""); - const roleAllConfigs = - role.configType === "image" - ? allImageConfigs - : role.configType === "vision" - ? allVisionConfigs - : allLLMConfigs; - - const assignedConfig = roleAllConfigs.find((config) => config.id === currentAssignment); - const isAssigned = !!assignedConfig; - - return ( -
- - - {/* Role Header */} -
-
-
- -
-
-

{role.title}

-

- {role.description} -

-
-
- {savingRole === role.prefKey ? ( - - ) : isAssigned ? ( - - ) : ( - - )} -
- - {/* Selector */} -
- - -
-
-
-
- ); - })} -
- )} -
- ); -} diff --git a/surfsense_web/components/settings/vision-model-manager.tsx b/surfsense_web/components/settings/vision-model-manager.tsx deleted file mode 100644 index 31578b4f1..000000000 --- a/surfsense_web/components/settings/vision-model-manager.tsx +++ /dev/null @@ -1,486 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { AlertCircle, Dot, Info, Pencil, RefreshCw, Trash2 } from "lucide-react"; -import { useMemo, useState } from "react"; -import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; -import { deleteVisionLLMConfigMutationAtom } from "@/atoms/vision-llm-config/vision-llm-config-mutation.atoms"; -import { - globalVisionLLMConfigsAtom, - visionLLMConfigsAtom, -} from "@/atoms/vision-llm-config/vision-llm-config-query.atoms"; -import { VisionConfigDialog } from "@/components/shared/vision-config-dialog"; -import { Alert, AlertDescription } from "@/components/ui/alert"; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, -} from "@/components/ui/alert-dialog"; -import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Card, CardContent } from "@/components/ui/card"; -import { Separator } from "@/components/ui/separator"; -import { Skeleton } from "@/components/ui/skeleton"; -import { Spinner } from "@/components/ui/spinner"; -import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; -import type { VisionLLMConfig } from "@/contracts/types/new-llm-config.types"; -import { useMediaQuery } from "@/hooks/use-media-query"; -import { getProviderIcon } from "@/lib/provider-icons"; -import { cn } from "@/lib/utils"; - -interface VisionModelManagerProps { - searchSpaceId: number; -} - -function getInitials(name: string): string { - const parts = name.trim().split(/\s+/); - if (parts.length >= 2) { - return (parts[0][0] + parts[1][0]).toUpperCase(); - } - return name.slice(0, 2).toUpperCase(); -} - -export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) { - const isDesktop = useMediaQuery("(min-width: 768px)"); - - const { - mutateAsync: deleteConfig, - isPending: isDeleting, - error: deleteError, - } = useAtomValue(deleteVisionLLMConfigMutationAtom); - - const { - data: userConfigs, - isFetching: configsLoading, - error: fetchError, - refetch: refreshConfigs, - } = useAtomValue(visionLLMConfigsAtom); - const { data: globalConfigs = [], isFetching: globalLoading } = useAtomValue( - globalVisionLLMConfigsAtom - ); - - const { data: members } = useAtomValue(membersAtom); - const memberMap = useMemo(() => { - const map = new Map(); - if (members) { - for (const m of members) { - map.set(m.user_id, { - name: m.user_display_name || m.user_email || "Unknown", - email: m.user_email || undefined, - avatarUrl: m.user_avatar_url || undefined, - }); - } - } - return map; - }, [members]); - - const { data: access } = useAtomValue(myAccessAtom); - const canCreate = useMemo(() => { - if (!access) return false; - if (access.is_owner) return true; - return access.permissions?.includes("vision_configs:create") ?? false; - }, [access]); - const canDelete = useMemo(() => { - if (!access) return false; - if (access.is_owner) return true; - return access.permissions?.includes("vision_configs:delete") ?? false; - }, [access]); - const canUpdate = canCreate; - const isReadOnly = !canCreate && !canDelete; - - const [isDialogOpen, setIsDialogOpen] = useState(false); - const [editingConfig, setEditingConfig] = useState(null); - const [configToDelete, setConfigToDelete] = useState(null); - - const isLoading = configsLoading || globalLoading; - const errors = [deleteError, fetchError].filter(Boolean) as Error[]; - - const openEditDialog = (config: VisionLLMConfig) => { - setEditingConfig(config); - setIsDialogOpen(true); - }; - - const openNewDialog = () => { - setEditingConfig(null); - setIsDialogOpen(true); - }; - - const handleDelete = async () => { - if (!configToDelete) return; - try { - await deleteConfig({ id: configToDelete.id, name: configToDelete.name }); - setConfigToDelete(null); - } catch { - // Error handled by mutation - } - }; - - return ( -
-
- - {canCreate && ( - - )} -
- - {errors.map((err) => ( -
- - - {err?.message} - -
- ))} - - {access && !isLoading && isReadOnly && ( -
- - - -

- You have read-only access to vision model - configurations. Contact a space owner to request additional permissions. -

-
-
-
- )} - {access && !isLoading && !isReadOnly && (!canCreate || !canDelete) && ( -
- - - -

- You can{" "} - {[canCreate && "create and edit", canDelete && "delete"] - .filter(Boolean) - .join(" and ")}{" "} - vision model configurations - {!canDelete && ", but cannot delete them"}. -

-
-
-
- )} - - {(isLoading || - globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0) && ( - - - - {isLoading ? ( -
- -
- ) : ( -

- - {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length}{" "} - global vision{" "} - {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length === - 1 - ? "model" - : "models"} - {" "} - available from your administrator. {(() => { - const nonAuto = globalConfigs.filter( - (g) => !("is_auto_mode" in g && g.is_auto_mode) - ); - const premium = nonAuto.filter( - (g) => - "billing_tier" in g && - (g as { billing_tier?: string }).billing_tier === "premium" - ).length; - const free = nonAuto.length - premium; - if (premium > 0 && free > 0) { - return `${premium} premium, ${free} free.`; - } - if (premium > 0) { - return `All ${premium} premium — debits your shared credit pool.`; - } - return `All ${free} free.`; - })()} -

- )} -
-
- )} - - {/* Global Vision Models — read-only cards with per-model Free/Premium - badges. Mirrors the badge palette used by the chat role selector - (`llm-role-manager.tsx`) so the meaning is consistent across - every model-configuration surface (chat / image / vision). */} - {!isLoading && - globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && ( -
-
- {globalConfigs - .filter((g) => !("is_auto_mode" in g && g.is_auto_mode)) - .map((cfg) => { - const billingTier = - ("billing_tier" in cfg && - typeof (cfg as { billing_tier?: string }).billing_tier === "string" && - (cfg as { billing_tier?: string }).billing_tier) || - "free"; - const isPremium = billingTier === "premium"; - return ( - - -
-
- {getProviderIcon(cfg.provider, { className: "size-4" })} -
-
-

- {cfg.name} -

- {isPremium ? ( - - Premium - - ) : ( - - Free - - )} -
-
- {cfg.description && ( -

- {cfg.description} -

- )} -
- -
- - {cfg.model_name} - -
-
-
-
- ); - })} -
-
- )} - - {isLoading && ( -
-
-
- {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( - - - - - - - - ))} -
-
-
- )} - - {!isLoading && ( -
- {(userConfigs?.length ?? 0) === 0 ? ( - - -

No Vision Models Yet

-

- {canCreate - ? "Add your own vision-capable model (GPT-4o, Claude, Gemini, etc.)" - : "No vision models have been added to this space yet. Contact a space owner to add one."} -

-
-
- ) : ( -
- {userConfigs?.map((config) => { - const member = config.user_id ? memberMap.get(config.user_id) : null; - - return ( -
- - - {/* Header: Icon + Name + Actions */} -
-
-
- {getProviderIcon(config.provider, { className: "size-4" })} -
-
-

- {config.name} -

- {config.description && ( -

- {config.description} -

- )} -
-
- {(canUpdate || canDelete) && ( -
- {canUpdate && ( - - - - - - Edit - - - )} - {canDelete && ( - - - - - - Delete - - - )} -
- )} -
- - {/* Footer: Date + Creator */} -
- -
- - {new Date(config.created_at).toLocaleDateString(undefined, { - year: "numeric", - month: "short", - day: "numeric", - })} - - {member && ( - <> - - - - -
- - {member.avatarUrl && ( - - )} - - {getInitials(member.name)} - - - - {member.name} - -
-
- - {member.email || member.name} - -
-
- - )} -
-
-
-
-
- ); - })} -
- )} -
- )} - - { - setIsDialogOpen(open); - if (!open) setEditingConfig(null); - }} - config={editingConfig} - isGlobal={false} - searchSpaceId={searchSpaceId} - mode={editingConfig ? "edit" : "create"} - /> - - !open && setConfigToDelete(null)} - > - - - Delete Vision Model - - Are you sure you want to delete{" "} - {configToDelete?.name}? - - - - Cancel - - Delete - {isDeleting && } - - - - -
- ); -} diff --git a/surfsense_web/components/shared/image-config-dialog.tsx b/surfsense_web/components/shared/image-config-dialog.tsx deleted file mode 100644 index 36d16081a..000000000 --- a/surfsense_web/components/shared/image-config-dialog.tsx +++ /dev/null @@ -1,456 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { AlertCircle, Check, ChevronsUpDown } from "lucide-react"; -import { useCallback, useEffect, useMemo, useRef, useState } from "react"; -import { toast } from "sonner"; -import { - createImageGenConfigMutationAtom, - updateImageGenConfigMutationAtom, -} from "@/atoms/image-gen-config/image-gen-config-mutation.atoms"; -import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; -import { Alert, AlertDescription } from "@/components/ui/alert"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { - Command, - CommandEmpty, - CommandGroup, - CommandInput, - CommandItem, - CommandList, -} from "@/components/ui/command"; -import { Dialog, DialogContent, DialogTitle } from "@/components/ui/dialog"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { Separator } from "@/components/ui/separator"; -import { Spinner } from "@/components/ui/spinner"; -import { IMAGE_GEN_MODELS, IMAGE_GEN_PROVIDERS } from "@/contracts/enums/image-gen-providers"; -import type { - GlobalImageGenConfig, - ImageGenerationConfig, - ImageGenProvider, -} from "@/contracts/types/new-llm-config.types"; -import { cn } from "@/lib/utils"; - -interface ImageConfigDialogProps { - open: boolean; - onOpenChange: (open: boolean) => void; - config: ImageGenerationConfig | GlobalImageGenConfig | null; - isGlobal: boolean; - searchSpaceId: number; - mode: "create" | "edit" | "view"; - defaultProvider?: string; -} - -const INITIAL_FORM = { - name: "", - description: "", - provider: "", - model_name: "", - api_key: "", - api_base: "", - api_version: "", -}; - -export function ImageConfigDialog({ - open, - onOpenChange, - config, - isGlobal, - searchSpaceId, - mode, - defaultProvider, -}: ImageConfigDialogProps) { - const [isSubmitting, setIsSubmitting] = useState(false); - const [formData, setFormData] = useState(INITIAL_FORM); - const [modelComboboxOpen, setModelComboboxOpen] = useState(false); - const [scrollPos, setScrollPos] = useState<"top" | "middle" | "bottom">("top"); - const scrollRef = useRef(null); - - useEffect(() => { - if (open) { - if (mode === "edit" && config && !isGlobal) { - setFormData({ - name: config.name || "", - description: config.description || "", - provider: config.provider || "", - model_name: config.model_name || "", - api_key: (config as ImageGenerationConfig).api_key || "", - api_base: config.api_base || "", - api_version: config.api_version || "", - }); - } else if (mode === "create") { - setFormData({ ...INITIAL_FORM, provider: defaultProvider ?? "" }); - } - setScrollPos("top"); - } - }, [open, mode, config, isGlobal, defaultProvider]); - - const { mutateAsync: createConfig } = useAtomValue(createImageGenConfigMutationAtom); - const { mutateAsync: updateConfig } = useAtomValue(updateImageGenConfigMutationAtom); - const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); - - const handleScroll = useCallback((e: React.UIEvent) => { - const el = e.currentTarget; - const atTop = el.scrollTop <= 2; - const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; - setScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); - }, []); - - const suggestedModels = useMemo(() => { - if (!formData.provider) return []; - return IMAGE_GEN_MODELS.filter((m) => m.provider === formData.provider); - }, [formData.provider]); - - const getTitle = () => { - if (mode === "create") return "Add Image Model"; - if (isGlobal) return "View Global Image Model"; - return "Edit Image Model"; - }; - - const getSubtitle = () => { - if (mode === "create") return "Set up a new image generation provider"; - if (isGlobal) return "Read-only global configuration"; - return "Update your image model settings"; - }; - - const handleSubmit = useCallback(async () => { - setIsSubmitting(true); - try { - if (mode === "create") { - const result = await createConfig({ - name: formData.name, - provider: formData.provider as ImageGenProvider, - model_name: formData.model_name, - api_key: formData.api_key, - api_base: formData.api_base || undefined, - api_version: formData.api_version || undefined, - description: formData.description || undefined, - search_space_id: searchSpaceId, - }); - if (result?.id) { - await updatePreferences({ - search_space_id: searchSpaceId, - data: { image_generation_config_id: result.id }, - }); - } - onOpenChange(false); - } else if (!isGlobal && config) { - await updateConfig({ - id: config.id, - data: { - name: formData.name, - description: formData.description || undefined, - provider: formData.provider as ImageGenProvider, - model_name: formData.model_name, - api_key: formData.api_key, - api_base: formData.api_base || undefined, - api_version: formData.api_version || undefined, - }, - }); - onOpenChange(false); - } - } catch (error) { - console.error("Failed to save image config:", error); - toast.error("Failed to save image model"); - } finally { - setIsSubmitting(false); - } - }, [ - mode, - isGlobal, - config, - formData, - searchSpaceId, - createConfig, - updateConfig, - updatePreferences, - onOpenChange, - ]); - - const handleUseGlobalConfig = useCallback(async () => { - if (!config || !isGlobal) return; - setIsSubmitting(true); - try { - await updatePreferences({ - search_space_id: searchSpaceId, - data: { image_generation_config_id: config.id }, - }); - toast.success(`Now using ${config.name}`); - onOpenChange(false); - } catch (error) { - console.error("Failed to set image model:", error); - toast.error("Failed to set image model"); - } finally { - setIsSubmitting(false); - } - }, [config, isGlobal, searchSpaceId, updatePreferences, onOpenChange]); - - const isFormValid = formData.name && formData.provider && formData.model_name && formData.api_key; - const selectedProvider = IMAGE_GEN_PROVIDERS.find((p) => p.value === formData.provider); - - return ( - - e.preventDefault()} - > - {getTitle()} - - {/* Header */} -
-
-
-

{getTitle()}

- {isGlobal && mode !== "create" && ( - - Global - - )} -
-

{getSubtitle()}

- {config && mode !== "create" && ( -

{config.model_name}

- )} -
-
- - {/* Scrollable content */} -
- {isGlobal && config && ( - <> - - - - Global configurations are read-only. To customize, create a new model. - - -
-
-
-
- Name -
-

{config.name}

-
- {config.description && ( -
-
- Description -
-

{config.description}

-
- )} -
- -
-
-
- Provider -
-

{config.provider}

-
-
-
- Model -
-

{config.model_name}

-
-
-
- - )} - - {(mode === "create" || (mode === "edit" && !isGlobal)) && ( -
-
- - setFormData((p) => ({ ...p, name: e.target.value }))} - /> -
- -
- - setFormData((p) => ({ ...p, description: e.target.value }))} - /> -
- - - -
- - -
- -
- - {suggestedModels.length > 0 ? ( - - - - - - - setFormData((p) => ({ ...p, model_name: val }))} - /> - - - - Type a custom model name - - - - {suggestedModels.map((m) => ( - { - setFormData((p) => ({ ...p, model_name: m.value })); - setModelComboboxOpen(false); - }} - > - - {m.value} - - {m.label} - - - ))} - - - - - - ) : ( - setFormData((p) => ({ ...p, model_name: e.target.value }))} - /> - )} -
- -
- - setFormData((p) => ({ ...p, api_key: e.target.value }))} - /> -
- -
- - setFormData((p) => ({ ...p, api_base: e.target.value }))} - /> -
- - {formData.provider === "AZURE_OPENAI" && ( -
- - setFormData((p) => ({ ...p, api_version: e.target.value }))} - /> -
- )} -
- )} -
- - {/* Fixed footer */} -
- - {mode === "create" || (mode === "edit" && !isGlobal) ? ( - - ) : isGlobal && config ? ( - - ) : null} -
-
-
- ); -} diff --git a/surfsense_web/components/shared/llm-config-form.tsx b/surfsense_web/components/shared/llm-config-form.tsx deleted file mode 100644 index 06de4129b..000000000 --- a/surfsense_web/components/shared/llm-config-form.tsx +++ /dev/null @@ -1,527 +0,0 @@ -"use client"; - -import { zodResolver } from "@hookform/resolvers/zod"; -import { useAtomValue } from "jotai"; -import { Check, ChevronDown, ChevronsUpDown } from "lucide-react"; -import { useEffect, useMemo, useState } from "react"; -import { type Resolver, useForm } from "react-hook-form"; -import { z } from "zod"; -import { - defaultSystemInstructionsAtom, - modelListAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible"; -import { - Command, - CommandEmpty, - CommandGroup, - CommandInput, - CommandItem, - CommandList, -} from "@/components/ui/command"; -import { - Form, - FormControl, - FormDescription, - FormField, - FormItem, - FormLabel, - FormMessage, -} from "@/components/ui/form"; -import { Input } from "@/components/ui/input"; -import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { Separator } from "@/components/ui/separator"; -import { Switch } from "@/components/ui/switch"; -import { Textarea } from "@/components/ui/textarea"; -import { LLM_PROVIDERS } from "@/contracts/enums/llm-providers"; -import type { CreateNewLLMConfigRequest } from "@/contracts/types/new-llm-config.types"; -import { cn } from "@/lib/utils"; -import InferenceParamsEditor from "../inference-params-editor"; - -// Form schema with zod -const formSchema = z.object({ - name: z.string().min(1, "Name is required").max(100), - description: z.string().max(500).optional().nullable(), - provider: z.string().min(1, "Provider is required"), - custom_provider: z.string().max(100).optional().nullable(), - model_name: z.string().min(1, "Model name is required").max(100), - api_key: z.string().min(1, "API key is required"), - api_base: z.string().max(500).optional().nullable(), - litellm_params: z.record(z.string(), z.any()).optional().nullable(), - system_instructions: z.string().default(""), - use_default_system_instructions: z.boolean().default(true), - citations_enabled: z.boolean().default(true), - search_space_id: z.number(), -}); - -type FormValues = z.infer; - -export type LLMConfigFormData = CreateNewLLMConfigRequest; - -interface LLMConfigFormProps { - initialData?: Partial; - searchSpaceId: number; - onSubmit: (data: LLMConfigFormData) => Promise; - mode?: "create" | "edit"; - showAdvanced?: boolean; - formId?: string; -} - -export function LLMConfigForm({ - initialData, - searchSpaceId, - onSubmit, - mode = "create", - showAdvanced = true, - formId, -}: LLMConfigFormProps) { - const { data: defaultInstructions, isSuccess: defaultInstructionsLoaded } = useAtomValue( - defaultSystemInstructionsAtom - ); - const { data: dynamicModels } = useAtomValue(modelListAtom); - const [modelComboboxOpen, setModelComboboxOpen] = useState(false); - const [advancedOpen, setAdvancedOpen] = useState(false); - const [systemInstructionsOpen, setSystemInstructionsOpen] = useState(false); - - const form = useForm({ - resolver: zodResolver(formSchema) as Resolver, - defaultValues: { - name: initialData?.name ?? "", - description: initialData?.description ?? "", - provider: initialData?.provider ?? "", - custom_provider: initialData?.custom_provider ?? "", - model_name: initialData?.model_name ?? "", - api_key: initialData?.api_key ?? "", - api_base: initialData?.api_base ?? "", - litellm_params: initialData?.litellm_params ?? {}, - system_instructions: initialData?.system_instructions ?? "", - use_default_system_instructions: initialData?.use_default_system_instructions ?? true, - citations_enabled: initialData?.citations_enabled ?? true, - search_space_id: searchSpaceId, - }, - }); - - // Load default instructions when available (only for new configs) - useEffect(() => { - if ( - mode === "create" && - defaultInstructionsLoaded && - defaultInstructions?.default_system_instructions && - !form.getValues("system_instructions") - ) { - form.setValue("system_instructions", defaultInstructions.default_system_instructions); - } - }, [defaultInstructionsLoaded, defaultInstructions, mode, form]); - - const watchProvider = form.watch("provider"); - const selectedProvider = LLM_PROVIDERS.find((p) => p.value === watchProvider); - const availableModels = useMemo( - () => (dynamicModels ?? []).filter((m) => m.provider === watchProvider), - [dynamicModels, watchProvider] - ); - - const handleProviderChange = (value: string) => { - form.setValue("provider", value); - form.setValue("model_name", ""); - - // Auto-fill API base for certain providers - const provider = LLM_PROVIDERS.find((p) => p.value === value); - if (provider?.apiBase) { - form.setValue("api_base", provider.apiBase); - } - }; - - const handleFormSubmit = async (values: FormValues) => { - await onSubmit(values as LLMConfigFormData); - }; - - return ( -
- - {/* Model Configuration Section */} -
-
- Model Configuration -
- - {/* Name & Description */} -
- ( - - Configuration Name - - - - - - )} - /> - - ( - - - Description - - Optional - - - - - - - - )} - /> -
- - {/* Provider Selection */} - ( - - LLM Provider - - - - )} - /> - - {/* Custom Provider (conditional) */} - {watchProvider === "CUSTOM" && ( - ( - - Custom Provider Name - - - - - - )} - /> - )} - - {/* Model Name with Combobox */} - ( - - Model Name - - - - - - - - - - - -
- {field.value ? `Using: "${field.value}"` : "Type your model name"} -
-
- {availableModels.length > 0 && ( - - {availableModels - .filter( - (model) => - !field.value || - model.value.toLowerCase().includes(field.value.toLowerCase()) || - model.label.toLowerCase().includes(field.value.toLowerCase()) - ) - .slice(0, 50) - .map((model) => ( - { - field.onChange(value); - setModelComboboxOpen(false); - }} - className="py-2" - > - -
-
{model.label}
- {model.contextWindow && ( -
- Context: {model.contextWindow} -
- )} -
-
- ))} -
- )} -
-
-
-
- {selectedProvider?.example && ( - - Example: {selectedProvider.example} - - )} - -
- )} - /> - - {/* API Credentials */} -
- ( - - API Key - - - - {watchProvider === "OLLAMA" && ( - - Ollama doesn't require auth — enter any value - - )} - - - )} - /> - - ( - - - API Base URL - {selectedProvider?.apiBase && ( - - Auto-filled - - )} - - - - - - - )} - /> -
- - {/* Ollama Quick Actions */} - {watchProvider === "OLLAMA" && ( -
- - -
- )} -
- - {/* Advanced Parameters */} - {showAdvanced && ( - <> - - - - - - - ( - - - - - - - )} - /> - - - - )} - - {/* System Instructions & Citations Section */} - - - - - - - {/* System Instructions */} - ( - -
- Instructions for the AI - {defaultInstructions && ( - - )} -
- -