From adb857925b97fa1508075c3c0d5a83bef31f828d Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:47:23 +0530 Subject: [PATCH 01/59] feat(models): add model connection persistence --- .../versions/156_add_model_connections.py | 178 +++++++++ surfsense_backend/app/db.py | 125 ++++++- surfsense_backend/app/routes/__init__.py | 2 + .../app/routes/model_connections_routes.py | 346 ++++++++++++++++++ surfsense_backend/app/schemas/__init__.py | 10 + .../app/schemas/model_connections.py | 89 +++++ .../app/services/model_connection_service.py | 209 +++++++++++ .../unit/services/test_model_connections.py | 75 ++++ 8 files changed, 1033 insertions(+), 1 deletion(-) create mode 100644 surfsense_backend/alembic/versions/156_add_model_connections.py create mode 100644 surfsense_backend/app/routes/model_connections_routes.py create mode 100644 surfsense_backend/app/schemas/model_connections.py create mode 100644 surfsense_backend/app/services/model_connection_service.py create mode 100644 surfsense_backend/tests/unit/services/test_model_connections.py diff --git a/surfsense_backend/alembic/versions/156_add_model_connections.py b/surfsense_backend/alembic/versions/156_add_model_connections.py new file mode 100644 index 000000000..0a11d7f9d --- /dev/null +++ b/surfsense_backend/alembic/versions/156_add_model_connections.py @@ -0,0 +1,178 @@ +"""add model connections + +Revision ID: 156 +Revises: 155 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +revision: str = "156" +down_revision: str | None = "155" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +connection_protocol = postgresql.ENUM( + "OLLAMA", + "OPENAI_COMPATIBLE", + "NATIVE", + name="connectionprotocol", + create_type=False, +) +connection_scope = postgresql.ENUM( + "GLOBAL", + "SEARCH_SPACE", + "USER", + name="connectionscope", + create_type=False, +) +model_source = postgresql.ENUM( + "DISCOVERED", + "MANUAL", + name="modelsource", + create_type=False, +) + + +def upgrade() -> None: + bind = op.get_bind() + connection_protocol.create(bind, checkfirst=True) + connection_scope.create(bind, checkfirst=True) + model_source.create(bind, checkfirst=True) + + op.create_table( + "connections", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("protocol", connection_protocol, nullable=False), + sa.Column("native_provider", sa.String(length=100), nullable=True), + sa.Column("base_url", sa.String(length=500), nullable=True), + sa.Column("api_key", sa.String(), nullable=True), + sa.Column( + "extra", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column("scope", connection_scope, nullable=False), + sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False), + sa.Column("search_space_id", sa.Integer(), nullable=True), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("last_verified_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column("last_status", sa.String(length=50), nullable=True), + sa.Column("last_error", sa.Text(), nullable=True), + sa.CheckConstraint( + "(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR " + "(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR " + "(scope = 'USER' AND user_id IS NOT NULL)", + name="ck_connections_scope_owner", + ), + sa.ForeignKeyConstraint( + ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_connections_protocol"), "connections", ["protocol"], unique=False) + op.create_index( + op.f("ix_connections_native_provider"), + "connections", + ["native_provider"], + unique=False, + ) + op.create_index(op.f("ix_connections_scope"), "connections", ["scope"], unique=False) + + op.create_table( + "models", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("connection_id", sa.Integer(), nullable=False), + sa.Column("model_id", sa.String(length=255), nullable=False), + sa.Column("display_name", sa.String(length=255), nullable=True), + sa.Column( + "source", + model_source, + server_default="DISCOVERED", + nullable=False, + ), + sa.Column( + "capabilities", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column( + "capabilities_declared", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column( + "capabilities_verified", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column( + "capabilities_override", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column("embedding_dimension", sa.Integer(), nullable=True), + sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False), + sa.Column("billing_tier", sa.String(length=50), nullable=True), + sa.Column( + "catalog", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.ForeignKeyConstraint(["connection_id"], ["connections.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint( + "connection_id", "model_id", name="uq_models_connection_model_id" + ), + ) + op.create_index(op.f("ix_models_connection_id"), "models", ["connection_id"], unique=False) + op.create_index("ix_models_model_id", "models", ["model_id"], unique=False) + op.create_index(op.f("ix_models_billing_tier"), "models", ["billing_tier"], unique=False) + + op.add_column( + "searchspaces", + sa.Column("chat_model_id", sa.Integer(), nullable=True), + ) + op.add_column( + "searchspaces", + sa.Column("image_gen_model_id", sa.Integer(), nullable=True), + ) + op.add_column( + "searchspaces", + sa.Column("vision_model_id", sa.Integer(), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("searchspaces", "vision_model_id") + op.drop_column("searchspaces", "image_gen_model_id") + op.drop_column("searchspaces", "chat_model_id") + + op.drop_index(op.f("ix_models_billing_tier"), table_name="models") + op.drop_index("ix_models_model_id", table_name="models") + op.drop_index(op.f("ix_models_connection_id"), table_name="models") + op.drop_table("models") + + op.drop_index(op.f("ix_connections_scope"), table_name="connections") + op.drop_index(op.f("ix_connections_native_provider"), table_name="connections") + op.drop_index(op.f("ix_connections_protocol"), table_name="connections") + op.drop_table("connections") + + bind = op.get_bind() + model_source.drop(bind, checkfirst=True) + connection_scope.drop(bind, checkfirst=True) + connection_protocol.drop(bind, checkfirst=True) diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 6117caecb..9756cb32f 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -280,6 +280,23 @@ class VisionProvider(StrEnum): CUSTOM = "CUSTOM" +class ConnectionProtocol(StrEnum): + OLLAMA = "OLLAMA" + OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" + NATIVE = "NATIVE" + + +class ConnectionScope(StrEnum): + GLOBAL = "GLOBAL" + SEARCH_SPACE = "SEARCH_SPACE" + USER = "USER" + + +class ModelSource(StrEnum): + DISCOVERED = "DISCOVERED" + MANUAL = "MANUAL" + + class LogLevel(StrEnum): DEBUG = "DEBUG" INFO = "INFO" @@ -1642,6 +1659,79 @@ class Report(BaseModel, TimestampMixin): thread = relationship("NewChatThread") +class Connection(BaseModel, TimestampMixin): + __tablename__ = "connections" + + protocol = Column(SQLAlchemyEnum(ConnectionProtocol), nullable=False, index=True) + native_provider = Column(String(100), nullable=True, index=True) + base_url = Column(String(500), nullable=True) + api_key = Column(String, nullable=True) + extra = Column(JSONB, nullable=False, default=dict, server_default="{}") + scope = Column(SQLAlchemyEnum(ConnectionScope), nullable=False, index=True) + enabled = Column(Boolean, nullable=False, default=True, server_default="true") + + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=True + ) + user_id = Column( + UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=True + ) + + last_verified_at = Column(TIMESTAMP(timezone=True), nullable=True) + last_status = Column(String(50), nullable=True) + last_error = Column(Text, nullable=True) + + search_space = relationship("SearchSpace", back_populates="connections") + user = relationship("User", back_populates="connections") + models = relationship( + "Model", + back_populates="connection", + order_by="Model.id", + cascade="all, delete-orphan", + passive_deletes=True, + ) + + __table_args__ = ( + CheckConstraint( + "(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR " + "(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR " + "(scope = 'USER' AND user_id IS NOT NULL)", + name="ck_connections_scope_owner", + ), + ) + + +class Model(BaseModel, TimestampMixin): + __tablename__ = "models" + + connection_id = Column( + Integer, ForeignKey("connections.id", ondelete="CASCADE"), nullable=False, index=True + ) + model_id = Column(String(255), nullable=False) + display_name = Column(String(255), nullable=True) + source = Column( + SQLAlchemyEnum(ModelSource), + nullable=False, + default=ModelSource.DISCOVERED, + server_default=ModelSource.DISCOVERED.value, + ) + capabilities = Column(JSONB, nullable=False, default=dict, server_default="{}") + capabilities_declared = Column(JSONB, nullable=False, default=dict, server_default="{}") + capabilities_verified = Column(JSONB, nullable=False, default=dict, server_default="{}") + capabilities_override = Column(JSONB, nullable=False, default=dict, server_default="{}") + embedding_dimension = Column(Integer, nullable=True) + enabled = Column(Boolean, nullable=False, default=True, server_default="true") + billing_tier = Column(String(50), nullable=True, index=True) + catalog = Column(JSONB, nullable=False, default=dict, server_default="{}") + + connection = relationship("Connection", back_populates="models") + + __table_args__ = ( + UniqueConstraint("connection_id", "model_id", name="uq_models_connection_model_id"), + Index("ix_models_model_id", "model_id"), + ) + + class ImageGenerationConfig(BaseModel, TimestampMixin): """ Dedicated configuration table for image generation models. @@ -1794,7 +1884,7 @@ class SearchSpace(BaseModel, TimestampMixin): # - Positive IDs: Custom configs from DB (NewLLMConfig table) agent_llm_id = Column( Integer, nullable=True, default=0 - ) # For agent/chat operations, defaults to Auto mode + ) # For chat operations, defaults to Auto mode image_generation_config_id = Column( Integer, nullable=True, default=0 ) # For image generation, defaults to Auto mode @@ -1802,6 +1892,22 @@ class SearchSpace(BaseModel, TimestampMixin): Integer, nullable=True, default=0 ) # For vision/screenshot analysis, defaults to Auto mode + # New connection/model role bindings. These supersede the legacy config + # columns above without removing them in this PR. + # Note: ID values preserve the existing convention: + # - 0: Auto mode + # - Negative IDs: Global virtual models from global_llm_config.yaml + # - Positive IDs: User/search-space models from the models table + chat_model_id = Column( + Integer, nullable=True, default=0 + ) # For agent/chat operations, defaults to Auto mode + image_gen_model_id = Column( + Integer, nullable=True, default=0 + ) # For image generation, defaults to Auto mode when eligible + vision_model_id = Column( + Integer, nullable=True, default=0 + ) # For vision/screenshot analysis, defaults to Auto mode + ai_file_sort_enabled = Column( Boolean, nullable=False, default=False, server_default="false" ) @@ -1889,6 +1995,13 @@ class SearchSpace(BaseModel, TimestampMixin): order_by="VisionLLMConfig.id", cascade="all, delete-orphan", ) + connections = relationship( + "Connection", + back_populates="search_space", + order_by="Connection.id", + cascade="all, delete-orphan", + passive_deletes=True, + ) automations = relationship( "Automation", @@ -2429,6 +2542,11 @@ if config.AUTH_TYPE == "GOOGLE": back_populates="user", passive_deletes=True, ) + connections = relationship( + "Connection", + back_populates="user", + passive_deletes=True, + ) # Automations created by this user automations = relationship( @@ -2568,6 +2686,11 @@ else: back_populates="user", passive_deletes=True, ) + connections = relationship( + "Connection", + back_populates="user", + passive_deletes=True, + ) # Automations created by this user automations = relationship( diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index 5cc029884..244208550 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -42,6 +42,7 @@ from .linear_add_connector_route import router as linear_add_connector_router from .logs_routes import router as logs_router from .luma_add_connector_route import router as luma_add_connector_router from .mcp_oauth_route import router as mcp_oauth_router +from .model_connections_routes import router as model_connections_router from .memory_routes import router as memory_router from .model_list_routes import router as model_list_router from .new_chat_routes import router as new_chat_router @@ -117,6 +118,7 @@ router.include_router(confluence_add_connector_router) router.include_router(clickup_add_connector_router) router.include_router(dropbox_add_connector_router) router.include_router(new_llm_config_router) # LLM configs with prompt configuration +router.include_router(model_connections_router) # Connection-centric model catalog router.include_router(model_list_router) # Dynamic model catalogue from OpenRouter router.include_router(logs_router) router.include_router(circleback_webhook_router) # Circleback meeting webhooks diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py new file mode 100644 index 000000000..69910183d --- /dev/null +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -0,0 +1,346 @@ +import logging + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from app.config import config +from app.db import ( + Connection, + ConnectionScope, + Model, + Permission, + SearchSpace, + User, + get_async_session, +) +from app.schemas import ( + ConnectionCreate, + ConnectionRead, + ConnectionUpdate, + ModelRead, + ModelRolesRead, + ModelRolesUpdate, + ModelUpdate, + VerifyConnectionResponse, +) +from app.services.model_connection_service import ( + discover_models, + persist_verification, + test_model, +) +from app.users import current_active_user +from app.utils.rbac import check_permission + +router = APIRouter() +logger = logging.getLogger(__name__) + + +def _model_read(model: Model | dict) -> ModelRead: + return ModelRead.model_validate(model) + + +def _connection_read(conn: Connection | dict, models: list[Model | dict] | None = None) -> ConnectionRead: + if isinstance(conn, dict): + payload = { + **conn, + "has_api_key": bool(conn.get("api_key")), + "api_key": None, + "models": [_model_read(model) for model in (models or [])], + } + payload.pop("api_key", None) + return ConnectionRead.model_validate(payload) + + return ConnectionRead( + id=conn.id, + protocol=conn.protocol, + native_provider=conn.native_provider, + base_url=conn.base_url, + extra=conn.extra or {}, + scope=conn.scope, + search_space_id=conn.search_space_id, + user_id=conn.user_id, + enabled=conn.enabled, + has_api_key=bool(conn.api_key), + last_verified_at=conn.last_verified_at, + last_status=conn.last_status, + last_error=conn.last_error, + models=[_model_read(model) for model in (models or conn.models or [])], + created_at=conn.created_at, + ) + + +async def _get_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace: + result = await session.execute(select(SearchSpace).where(SearchSpace.id == search_space_id)) + search_space = result.scalars().first() + if not search_space: + raise HTTPException(status_code=404, detail="Search space not found") + return search_space + + +async def _load_connection(session: AsyncSession, connection_id: int) -> Connection: + result = await session.execute( + select(Connection) + .options(selectinload(Connection.models)) + .where(Connection.id == connection_id) + ) + conn = result.scalars().first() + if not conn: + raise HTTPException(status_code=404, detail="Connection not found") + return conn + + +async def _assert_connection_access( + session: AsyncSession, + user: User, + conn: Connection, + permission: str = Permission.LLM_CONFIGS_CREATE.value, +) -> None: + if conn.search_space_id: + await check_permission( + session, + user, + conn.search_space_id, + permission, + "You don't have permission to manage model connections in this search space", + ) + return + if conn.user_id != user.id: + raise HTTPException(status_code=403, detail="Connection does not belong to user") + + +@router.get("/global-model-connections", response_model=list[ConnectionRead]) +async def list_global_connections(user: User = Depends(current_active_user)): + del user + models_by_connection: dict[int, list[dict]] = {} + for model in config.GLOBAL_MODELS: + models_by_connection.setdefault(model["connection_id"], []).append(model) + return [ + _connection_read(conn, models_by_connection.get(conn["id"], [])) + for conn in config.GLOBAL_CONNECTIONS + ] + + +@router.get("/model-connections", response_model=list[ConnectionRead]) +async def list_connections( + search_space_id: int | None = None, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + stmt = select(Connection).options(selectinload(Connection.models)) + if search_space_id is not None: + await check_permission( + session, + user, + search_space_id, + Permission.LLM_CONFIGS_CREATE.value, + "You don't have permission to view model connections in this search space", + ) + stmt = stmt.where(Connection.search_space_id == search_space_id) + else: + stmt = stmt.where(Connection.user_id == user.id) + result = await session.execute(stmt.order_by(Connection.id)) + return [_connection_read(conn) for conn in result.scalars().all()] + + +@router.post("/model-connections", response_model=ConnectionRead) +async def create_connection( + data: ConnectionCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + if data.scope == ConnectionScope.GLOBAL: + raise HTTPException(status_code=400, detail="GLOBAL connections are YAML-only") + if data.scope == ConnectionScope.SEARCH_SPACE: + if data.search_space_id is None: + raise HTTPException(status_code=400, detail="search_space_id is required") + await check_permission( + session, + user, + data.search_space_id, + Permission.LLM_CONFIGS_CREATE.value, + "You don't have permission to create model connections in this search space", + ) + conn = Connection( + **data.model_dump(exclude={"search_space_id"}), + search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None, + user_id=user.id, + ) + session.add(conn) + await session.commit() + await session.refresh(conn) + return _connection_read(conn) + + +@router.put("/model-connections/{connection_id}", response_model=ConnectionRead) +async def update_connection( + connection_id: int, + data: ConnectionUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + conn = await _load_connection(session, connection_id) + await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value) + for key, value in data.model_dump(exclude_unset=True).items(): + setattr(conn, key, value) + await session.commit() + await session.refresh(conn) + return _connection_read(conn) + + +@router.delete("/model-connections/{connection_id}") +async def delete_connection( + connection_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + conn = await _load_connection(session, connection_id) + await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_DELETE.value) + await session.delete(conn) + await session.commit() + return {"status": "deleted"} + + +@router.post("/model-connections/{connection_id}/verify", response_model=VerifyConnectionResponse) +async def verify_model_connection( + connection_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + conn = await _load_connection(session, connection_id) + await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_CREATE.value) + result = await persist_verification(conn) + await session.commit() + return VerifyConnectionResponse(status=result.status, ok=result.ok, message=result.message) + + +@router.post("/model-connections/{connection_id}/discover", response_model=list[ModelRead]) +async def discover_connection_models( + connection_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + conn = await _load_connection(session, connection_id) + await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_CREATE.value) + discovered = await discover_models(conn) + by_model_id = {model.model_id: model for model in conn.models} + for item in discovered: + db_model = by_model_id.get(item["model_id"]) + if db_model is None: + db_model = Model( + connection_id=conn.id, + model_id=item["model_id"], + display_name=item.get("display_name"), + source=item["source"], + capabilities=item["capabilities"], + capabilities_declared=item["capabilities"], + capabilities_verified={}, + capabilities_override={}, + enabled=False, + catalog=item.get("metadata") or {}, + ) + session.add(db_model) + else: + db_model.display_name = item.get("display_name") or db_model.display_name + db_model.capabilities_declared = item["capabilities"] + db_model.capabilities = { + **item["capabilities"], + **(db_model.capabilities_override or {}), + } + db_model.catalog = item.get("metadata") or db_model.catalog + await session.commit() + await session.refresh(conn) + return [_model_read(model) for model in conn.models] + + +@router.put("/models/{model_id}", response_model=ModelRead) +async def update_model( + model_id: int, + data: ModelUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + result = await session.execute( + select(Model).options(selectinload(Model.connection)).where(Model.id == model_id) + ) + model = result.scalars().first() + if not model: + raise HTTPException(status_code=404, detail="Model not found") + await _assert_connection_access(session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value) + update = data.model_dump(exclude_unset=True) + for key, value in update.items(): + setattr(model, key, value) + if "capabilities_override" in update: + model.capabilities = { + **(model.capabilities_declared or {}), + **(model.capabilities_override or {}), + } + await session.commit() + await session.refresh(model) + return _model_read(model) + + +@router.post("/models/{model_id}/test", response_model=VerifyConnectionResponse) +async def test_connection_model( + model_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + result = await session.execute( + select(Model).options(selectinload(Model.connection)).where(Model.id == model_id) + ) + model = result.scalars().first() + if not model: + raise HTTPException(status_code=404, detail="Model not found") + await _assert_connection_access(session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value) + result = await test_model(model.connection, model) + await session.commit() + return VerifyConnectionResponse(status=result.status, ok=result.ok, message=result.message) + + +@router.get("/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead) +async def get_model_roles( + search_space_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + await check_permission( + session, + user, + search_space_id, + Permission.LLM_CONFIGS_CREATE.value, + "You don't have permission to view model roles in this search space", + ) + search_space = await _get_search_space(session, search_space_id) + return ModelRolesRead( + chat_model_id=search_space.chat_model_id, + vision_model_id=search_space.vision_model_id, + image_gen_model_id=search_space.image_gen_model_id, + ) + + +@router.put("/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead) +async def update_model_roles( + search_space_id: int, + data: ModelRolesUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + await check_permission( + session, + user, + search_space_id, + Permission.LLM_CONFIGS_UPDATE.value, + "You don't have permission to update model roles in this search space", + ) + search_space = await _get_search_space(session, search_space_id) + for key, value in data.model_dump(exclude_unset=True).items(): + setattr(search_space, key, value) + await session.commit() + await session.refresh(search_space) + return ModelRolesRead( + chat_model_id=search_space.chat_model_id, + vision_model_id=search_space.vision_model_id, + image_gen_model_id=search_space.image_gen_model_id, + ) diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index fdf34672b..c14671c99 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -44,6 +44,16 @@ from .image_generation import ( ImageGenerationRead, ) from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate +from .model_connections import ( + ConnectionCreate, + ConnectionRead, + ConnectionUpdate, + ModelRead, + ModelRolesRead, + ModelRolesUpdate, + ModelUpdate, + VerifyConnectionResponse, +) from .new_chat import ( ChatMessage, NewChatMessageAppend, diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py new file mode 100644 index 000000000..731064375 --- /dev/null +++ b/surfsense_backend/app/schemas/model_connections.py @@ -0,0 +1,89 @@ +import uuid +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + +from app.db import ConnectionProtocol, ConnectionScope, ModelSource + + +class ModelRead(BaseModel): + id: int + connection_id: int + model_id: str + display_name: str | None = None + source: ModelSource | str + capabilities: dict[str, Any] + capabilities_declared: dict[str, Any] = Field(default_factory=dict) + capabilities_verified: dict[str, Any] = Field(default_factory=dict) + capabilities_override: dict[str, Any] = Field(default_factory=dict) + embedding_dimension: int | None = None + enabled: bool + billing_tier: str | None = None + catalog: dict[str, Any] = Field(default_factory=dict) + created_at: datetime | None = None + + model_config = ConfigDict(from_attributes=True) + + +class ConnectionRead(BaseModel): + id: int + protocol: ConnectionProtocol | str + native_provider: str | None = None + base_url: str | None = None + extra: dict[str, Any] = Field(default_factory=dict) + scope: ConnectionScope | str + search_space_id: int | None = None + user_id: uuid.UUID | None = None + enabled: bool + has_api_key: bool + last_verified_at: datetime | None = None + last_status: str | None = None + last_error: str | None = None + models: list[ModelRead] = Field(default_factory=list) + created_at: datetime | None = None + + model_config = ConfigDict(from_attributes=True) + + +class ConnectionCreate(BaseModel): + protocol: ConnectionProtocol + native_provider: str | None = None + base_url: str | None = Field(None, max_length=500) + api_key: str | None = None + extra: dict[str, Any] = Field(default_factory=dict) + scope: ConnectionScope = ConnectionScope.SEARCH_SPACE + search_space_id: int | None = None + enabled: bool = True + + +class ConnectionUpdate(BaseModel): + native_provider: str | None = None + base_url: str | None = Field(None, max_length=500) + api_key: str | None = None + extra: dict[str, Any] | None = None + enabled: bool | None = None + + +class ModelUpdate(BaseModel): + display_name: str | None = Field(None, max_length=255) + enabled: bool | None = None + capabilities_override: dict[str, Any] | None = None + + +class VerifyConnectionResponse(BaseModel): + status: str + ok: bool + message: str = "" + + +class ModelRolesRead(BaseModel): + chat_model_id: int | None = 0 + vision_model_id: int | None = 0 + image_gen_model_id: int | None = 0 + + +class ModelRolesUpdate(BaseModel): + chat_model_id: int | None = None + vision_model_id: int | None = None + image_gen_model_id: int | None = None diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py new file mode 100644 index 000000000..81090acaf --- /dev/null +++ b/surfsense_backend/app/services/model_connection_service.py @@ -0,0 +1,209 @@ +"""Connection verification, model discovery, and capability probing.""" + +from __future__ import annotations + +import contextlib +import logging +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Any + +import httpx +import litellm + +from app.db import Connection, ConnectionProtocol, Model, ModelSource +from app.services.model_resolver import ensure_v1, to_litellm + +logger = logging.getLogger(__name__) + +VERIFY_TIMEOUT_SECONDS = 8.0 +DISCOVERY_TIMEOUT_SECONDS = 15.0 +TEST_TIMEOUT_SECONDS = 30.0 + + +@dataclass(frozen=True) +class VerifyResult: + status: str + ok: bool + message: str = "" + + +def _auth_headers(conn: Connection) -> dict[str, str]: + if not conn.api_key: + return {} + return {"Authorization": f"Bearer {conn.api_key}"} + + +def _docker_hint(url: str | None, exc_or_status: Any) -> str: + raw = str(exc_or_status) + if not url: + return raw + if "localhost" in url or "127.0.0.1" in url: + return ( + f"{raw}. The backend is running inside Docker; localhost means the " + "backend container. Use host.docker.internal and make sure the model " + "server listens on 0.0.0.0." + ) + if "host.docker.internal" in url and ("refused" in raw.lower() or "connect" in raw.lower()): + return ( + f"{raw}. The host is reachable only if your local model server is " + "listening on 0.0.0.0. On Linux Docker, add " + "`host.docker.internal:host-gateway` to extra_hosts." + ) + return raw + + +async def verify_connection(conn: Connection) -> VerifyResult: + if not conn.base_url and conn.protocol in ( + ConnectionProtocol.OLLAMA, + ConnectionProtocol.OPENAI_COMPATIBLE, + ): + return VerifyResult("UNREACHABLE", False, "Base URL is required.") + + if conn.protocol == ConnectionProtocol.OLLAMA: + url = f"{conn.base_url.rstrip('/')}/api/version" + elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: + url = f"{ensure_v1(conn.base_url)}/models" + else: + # Native providers do not share one cheap health endpoint. The model + # probe exercises the real path and is the authoritative check. + return VerifyResult("OK", True, "Native provider configuration accepted.") + + try: + async with httpx.AsyncClient(timeout=VERIFY_TIMEOUT_SECONDS) as client: + response = await client.get(url, headers=_auth_headers(conn)) + if response.status_code in (401, 403): + return VerifyResult("AUTH_FAILED", False, "Authentication failed.") + if response.status_code == 404: + if conn.protocol == ConnectionProtocol.OLLAMA and url.endswith("/v1/models"): + message = "Ollama native API should not use /v1." + elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: + message = "OpenAI-compatible servers should expose /v1/models." + else: + message = "Endpoint returned 404." + return VerifyResult("NOT_FOUND", False, message) + response.raise_for_status() + return VerifyResult("OK", True, "Connection verified.") + except httpx.ConnectError as exc: + return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc)) + except httpx.TimeoutException as exc: + return VerifyResult("UNREACHABLE", False, f"Connection timed out: {exc}") + except httpx.HTTPError as exc: + return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc)) + + +async def persist_verification(conn: Connection) -> VerifyResult: + result = await verify_connection(conn) + conn.last_verified_at = datetime.now(UTC) + conn.last_status = result.status + conn.last_error = "" if result.ok else result.message + return result + + +def _litellm_capabilities(model_string: str, model_id: str) -> dict[str, bool]: + capabilities = { + "chat": True, + "vision": False, + "tools": False, + "image_gen": False, + "embedding": False, + } + with contextlib.suppress(Exception): + capabilities["vision"] = bool(litellm.supports_vision(model=model_string)) + with contextlib.suppress(Exception): + capabilities["tools"] = bool(litellm.supports_function_calling(model=model_string)) + try: + info = litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {} + mode = str(info.get("mode") or "") + capabilities["embedding"] = mode == "embedding" + capabilities["image_gen"] = mode in {"image_generation", "image_generation_model"} + except Exception: + pass + return capabilities + + +def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, bool]: + metadata = metadata or {} + if conn.protocol == ConnectionProtocol.OLLAMA: + caps = metadata.get("capabilities") or [] + capabilities = { + "chat": True, + "vision": "vision" in caps, + "tools": False, + "image_gen": False, + "embedding": "embedding" in caps, + } + return capabilities + + model_string, _ = to_litellm(conn, model_id) + return _litellm_capabilities(model_string, model_id) + + +async def discover_models(conn: Connection) -> list[dict[str, Any]]: + if conn.protocol == ConnectionProtocol.OLLAMA: + url = f"{conn.base_url.rstrip('/')}/api/tags" + async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: + response = await client.get(url, headers=_auth_headers(conn)) + response.raise_for_status() + models = response.json().get("models", []) + return [ + { + "model_id": item.get("model") or item.get("name"), + "display_name": item.get("name") or item.get("model"), + "source": ModelSource.DISCOVERED, + "capabilities": derive_capabilities(conn, item.get("model") or item.get("name"), item), + "metadata": item, + } + for item in models + if item.get("model") or item.get("name") + ] + + if conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: + url = f"{ensure_v1(conn.base_url)}/models" + async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: + response = await client.get(url, headers=_auth_headers(conn)) + response.raise_for_status() + models = response.json().get("data", []) + return [ + { + "model_id": item.get("id"), + "display_name": item.get("name") or item.get("id"), + "source": ModelSource.DISCOVERED, + "capabilities": derive_capabilities(conn, item.get("id"), item), + "metadata": item, + } + for item in models + if item.get("id") + ] + + # Native providers rely on curated/global catalog entries or manual rows. + return [] + + +async def test_model(conn: Connection, model: Model) -> VerifyResult: + model_string, kwargs = to_litellm(conn, model.model_id) + try: + await litellm.acompletion( + model=model_string, + messages=[{"role": "user", "content": "Hello"}], + timeout=TEST_TIMEOUT_SECONDS, + **kwargs, + ) + except Exception as exc: + return VerifyResult("UNREACHABLE", False, str(exc)) + + model.capabilities_verified = { + **(model.capabilities_verified or {}), + "chat": True, + } + return VerifyResult("OK", True, "Model test succeeded.") + + +__all__ = [ + "VerifyResult", + "derive_capabilities", + "discover_models", + "persist_verification", + "test_model", + "verify_connection", +] diff --git a/surfsense_backend/tests/unit/services/test_model_connections.py b/surfsense_backend/tests/unit/services/test_model_connections.py new file mode 100644 index 000000000..98042501b --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_model_connections.py @@ -0,0 +1,75 @@ +from app.services.global_model_catalog import materialize_global_model_catalog +from app.services.model_resolver import ensure_v1, to_litellm + + +def test_openai_compatible_resolver_normalizes_v1() -> None: + model, kwargs = to_litellm( + { + "protocol": "OPENAI_COMPATIBLE", + "base_url": "http://host.docker.internal:1234", + "api_key": "local-key", + "extra": {}, + }, + "qwen/qwen3", + ) + + assert model == "openai/qwen/qwen3" + assert kwargs["api_base"] == "http://host.docker.internal:1234/v1" + assert kwargs["api_key"] == "local-key" + assert ensure_v1("http://example.com/v1") == "http://example.com/v1" + + +def test_ollama_resolver_uses_native_api_base() -> None: + model, kwargs = to_litellm( + { + "protocol": "OLLAMA", + "base_url": "http://host.docker.internal:11434", + "api_key": None, + "extra": {}, + }, + "llama3.2", + ) + + assert model == "ollama_chat/llama3.2" + assert kwargs["api_base"] == "http://host.docker.internal:11434" + + +def test_global_materialization_preserves_tier_and_keeps_key_server_side() -> None: + connections, models = materialize_global_model_catalog( + chat_configs=[ + { + "id": -101, + "name": "OpenRouter Free", + "provider": "OPENROUTER", + "model_name": "meta-llama/llama-3.1-8b-instruct:free", + "api_key": "sk-global-secret", + "billing_tier": "free", + "anonymous_enabled": True, + "seo_enabled": True, + "rpm": 10, + "tpm": 1000, + }, + { + "id": -102, + "name": "OpenRouter Premium", + "provider": "OPENROUTER", + "model_name": "anthropic/claude-sonnet-4", + "api_key": "sk-global-secret", + "billing_tier": "premium", + }, + ], + vision_configs=[], + image_configs=[], + ) + + assert len(connections) == 1 + assert connections[0]["api_key"] == "sk-global-secret" + assert {model["billing_tier"] for model in models} == {"free", "premium"} + assert models[0]["catalog"]["anonymous_enabled"] is True + assert models[0]["catalog"]["rpm"] == 10 + + public_connections = [ + {key: value for key, value in connection.items() if key != "api_key"} + for connection in connections + ] + assert "sk-" not in repr(public_connections) From 8b59ca59c11e6105d6b2558b73657766bf726869 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:47:42 +0530 Subject: [PATCH 02/59] feat(models): add provider catalog and resolver --- surfsense_backend/app/config/__init__.py | 33 +++- .../app/config/global_llm_config.example.yaml | 45 +++--- .../app/schemas/new_llm_config.py | 2 +- .../app/services/global_model_catalog.py | 142 ++++++++++++++++ .../app/services/model_resolver.py | 152 ++++++++++++++++++ .../app/services/provider_capabilities.py | 37 +---- 6 files changed, 355 insertions(+), 56 deletions(-) create mode 100644 surfsense_backend/app/services/global_model_catalog.py create mode 100644 surfsense_backend/app/services/model_resolver.py diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 75af17d11..b8addb45d 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -389,10 +389,28 @@ def initialize_openrouter_integration(): ) except Exception as e: print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}") + + refresh_global_model_catalog() except Exception as e: print(f"Warning: Failed to initialize OpenRouter integration: {e}") +def materialize_global_configs(): + from app.services.global_model_catalog import materialize_global_model_catalog + + return materialize_global_model_catalog( + chat_configs=getattr(config, "GLOBAL_LLM_CONFIGS", []), + vision_configs=getattr(config, "GLOBAL_VISION_LLM_CONFIGS", []), + image_configs=getattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", []), + ) + + +def refresh_global_model_catalog(): + connections, models = materialize_global_configs() + config.GLOBAL_CONNECTIONS = connections + config.GLOBAL_MODELS = models + + def initialize_pricing_registration(): """ Teach LiteLLM the per-token cost of every deployment in @@ -723,7 +741,7 @@ class Config: os.getenv("QUOTA_DEFAULT_IMAGE_RESERVE_MICROS", "50000") ) - # Per-podcast reservation (in micro-USD). One agent LLM call generating + # Per-podcast reservation (in micro-USD). One chat model call generating # a transcript, typically 5k-20k completion tokens. $0.20 covers a long # premium-model run. Tune via env. QUOTA_DEFAULT_PODCAST_RESERVE_MICROS = int( @@ -849,6 +867,19 @@ class Config: # Router settings for Vision LLM Auto mode VISION_LLM_ROUTER_SETTINGS = load_vision_llm_router_settings() + # Virtual GLOBAL connection/model catalog. This is server-only metadata + # derived from global_llm_config.yaml; GLOBAL keys are not stored in DB. + from app.services.global_model_catalog import ( + materialize_global_model_catalog as _materialize_global_model_catalog, + ) + + GLOBAL_CONNECTIONS, GLOBAL_MODELS = _materialize_global_model_catalog( + chat_configs=GLOBAL_LLM_CONFIGS, + vision_configs=GLOBAL_VISION_LLM_CONFIGS, + image_configs=GLOBAL_IMAGE_GEN_CONFIGS, + ) + del _materialize_global_model_catalog + # OpenRouter Integration settings (optional) OPENROUTER_INTEGRATION_SETTINGS = load_openrouter_integration_settings() diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index 1c09a91ac..b0eee6458 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -7,8 +7,9 @@ # NOTE: The example API keys below are placeholders and won't work. # Replace them with your actual API keys to enable global configurations. # -# These configurations will be available to all users as a convenient option -# Users can choose to use these global configs or add their own +# These configurations are materialized as server-owned GLOBAL connections/models +# and become available on the Models page. Users can choose hosted/global models +# or add their own BYOK/local connections. # # AUTO MODE (Recommended): # - Auto mode (ID: 0) uses LiteLLM Router to automatically load balance across all global configs @@ -16,9 +17,12 @@ # - New users are automatically assigned Auto mode by default # - Configure router_settings below to customize the load balancing behavior # -# Structure matches NewLLMConfig: -# - Model configuration (provider, model_name, api_key, etc.) -# - Prompt configuration (system_instructions, citations_enabled) +# Static config shape: +# - Connection fields: provider, api_key, api_base, api_version +# - Model fields: model_name, billing_tier, rpm/tpm, litellm_params +# - Prompt defaults: system_instructions, citations_enabled +# IDs share one GLOBAL model namespace across chat, vision, and image generation. +# Suggested ranges: chat -1..-999, vision -1001..-1999, image -2001..-2999. # # COST-BASED PREMIUM CREDITS: # Each premium config bills the user's USD-credit balance based on the @@ -327,7 +331,7 @@ openrouter_integration: quota_reserve_tokens: 4000 # id_offset: base negative ID for dynamically generated configs. # Model IDs are derived deterministically via BLAKE2b so they survive - # catalogue churn. Must not overlap with your static global_llm_configs IDs. + # catalogue churn. Must not overlap with any static GLOBAL model IDs. id_offset: -10000 # refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only) refresh_interval_hours: 24 @@ -351,8 +355,8 @@ openrouter_integration: # Image generation + vision LLM emission are OPT-IN. OpenRouter's catalogue # contains hundreds of image- and vision-capable models; turning these on - # injects them into the global Image-Generation / Vision-LLM model - # selectors alongside any static configs. Tier (free/premium) is derived + # injects them into the global image-generation / vision model lists + # alongside any static configs. Tier (free/premium) is derived # per model the same way it is for chat (`:free` suffix or zero pricing). # When a user picks a premium image/vision model the call debits the # shared $5 USD-cost-based premium credit pool — so leaving these off @@ -384,7 +388,7 @@ image_generation_router_settings: global_image_generation_configs: # Example: OpenAI DALL-E 3 - - id: -1 + - id: -2001 name: "Global DALL-E 3" description: "OpenAI's DALL-E 3 for high-quality image generation" provider: "OPENAI" @@ -395,7 +399,7 @@ global_image_generation_configs: litellm_params: {} # Example: OpenAI GPT Image 1 - - id: -2 + - id: -2002 name: "Global GPT Image 1" description: "OpenAI's GPT Image 1 model" provider: "OPENAI" @@ -406,7 +410,7 @@ global_image_generation_configs: litellm_params: {} # Example: Azure OpenAI DALL-E 3 - - id: -3 + - id: -2003 name: "Global Azure DALL-E 3" description: "Azure-hosted DALL-E 3 deployment" provider: "AZURE_OPENAI" @@ -419,7 +423,7 @@ global_image_generation_configs: base_model: "dall-e-3" # Example: OpenRouter Gemini Image Generation - # - id: -4 + # - id: -2004 # name: "Global Gemini Image Gen" # description: "Google Gemini image generation via OpenRouter" # provider: "OPENROUTER" @@ -448,7 +452,7 @@ vision_llm_router_settings: global_vision_llm_configs: # Example: OpenAI GPT-4o (recommended for vision) - - id: -1 + - id: -1001 name: "Global GPT-4o Vision" description: "OpenAI's GPT-4o with strong vision capabilities" provider: "OPENAI" @@ -462,7 +466,7 @@ global_vision_llm_configs: max_tokens: 1000 # Example: Google Gemini 2.0 Flash - - id: -2 + - id: -1002 name: "Global Gemini 2.0 Flash" description: "Google's fast vision model with large context" provider: "GOOGLE" @@ -476,7 +480,7 @@ global_vision_llm_configs: max_tokens: 1000 # Example: Anthropic Claude 3.5 Sonnet - - id: -3 + - id: -1003 name: "Global Claude 3.5 Sonnet Vision" description: "Anthropic's Claude 3.5 Sonnet with vision support" provider: "ANTHROPIC" @@ -490,7 +494,7 @@ global_vision_llm_configs: max_tokens: 1000 # Example: Azure OpenAI GPT-4o - # - id: -4 + # - id: -1004 # name: "Global Azure GPT-4o Vision" # description: "Azure-hosted GPT-4o for vision analysis" # provider: "AZURE_OPENAI" @@ -507,8 +511,9 @@ global_vision_llm_configs: # Notes: # - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing -# - Use negative IDs to distinguish global configs from user configs (NewLLMConfig in DB) -# - IDs should be unique and sequential (e.g., -1, -2, -3, etc.) +# - Use negative IDs to distinguish global models from BYOK/local DB models +# - IDs must be unique across chat, vision, and image generation configs +# - Suggested static ranges: chat -1..-999, vision -1001..-1999, image -2001..-2999 # - The 'api_key' field will not be exposed to users via API # - system_instructions: Custom prompt or empty string to use defaults # - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty @@ -519,7 +524,7 @@ global_vision_llm_configs: # # # IMAGE GENERATION NOTES: -# - Image generation configs use the same ID scheme as LLM configs (negative for global) +# - Image generation configs use the shared GLOBAL ID namespace # - Supported models: dall-e-2, dall-e-3, gpt-image-1 (OpenAI), azure/* (Azure), # bedrock/* (AWS), vertex_ai/* (Google), recraft/* (Recraft), openrouter/* (OpenRouter) # - The router uses litellm.aimage_generation() for async image generation @@ -527,7 +532,7 @@ global_vision_llm_configs: # TPM (tokens per minute) does not apply since image APIs are billed/rate-limited per request, not per token. # # VISION LLM NOTES: -# - Vision configs use the same ID scheme (negative for global, positive for user DB) +# - Vision configs use the shared GLOBAL ID namespace # - Only use vision-capable models (GPT-4o, Gemini, Claude 3, etc.) # - Lower temperature (0.3) is recommended for accurate screenshot analysis # - Lower max_tokens (1000) is sufficient since autocomplete produces short suggestions diff --git a/surfsense_backend/app/schemas/new_llm_config.py b/surfsense_backend/app/schemas/new_llm_config.py index 716aa0457..2f04a9e66 100644 --- a/surfsense_backend/app/schemas/new_llm_config.py +++ b/surfsense_backend/app/schemas/new_llm_config.py @@ -229,7 +229,7 @@ class LLMPreferencesRead(BaseModel): description="ID of the vision LLM config to use for vision/screenshot analysis", ) agent_llm: dict[str, Any] | None = Field( - None, description="Full config for agent LLM" + None, description="Full config for chat model" ) image_generation_config: dict[str, Any] | None = Field( None, description="Full config for image generation" diff --git a/surfsense_backend/app/services/global_model_catalog.py b/surfsense_backend/app/services/global_model_catalog.py new file mode 100644 index 000000000..a43f58b9e --- /dev/null +++ b/surfsense_backend/app/services/global_model_catalog.py @@ -0,0 +1,142 @@ +"""Materialize server-owned GLOBAL YAML configs as virtual connections/models.""" + +from __future__ import annotations + +from typing import Any + +from app.services.model_resolver import native_connection_from_config + + +def _base_model(config: dict[str, Any]) -> str | None: + litellm_params = config.get("litellm_params") or {} + if isinstance(litellm_params, dict): + return litellm_params.get("base_model") + return None + + +def _connection_key(conn: dict[str, Any]) -> tuple[Any, ...]: + # Deliberately includes api_key because two operator-owned credentials for + # the same provider/base can have different quota/rate limits upstream. + return ( + conn.get("protocol"), + conn.get("native_provider"), + conn.get("base_url"), + conn.get("api_key"), + _freeze(conn.get("extra") or {}), + ) + + +def _freeze(value: Any) -> Any: + if isinstance(value, dict): + return tuple(sorted((key, _freeze(val)) for key, val in value.items())) + if isinstance(value, list): + return tuple(_freeze(item) for item in value) + return value + + +def _capabilities_for(role: str, config: dict[str, Any]) -> dict[str, bool]: + return { + "chat": role == "chat", + "vision": role == "vision" or bool(config.get("supports_image_input")), + "image_gen": role == "image_gen", + "embedding": False, + "tools": bool(config.get("supports_tools", False)), + } + + +def _catalog_metadata(config: dict[str, Any]) -> dict[str, Any]: + return { + "billing_tier": config.get("billing_tier", "free"), + "quota_reserve_tokens": config.get("quota_reserve_tokens"), + "rpm": config.get("rpm"), + "tpm": config.get("tpm"), + "anonymous_enabled": config.get("anonymous_enabled", False), + "seo_enabled": config.get("seo_enabled", False), + "seo_slug": config.get("seo_slug"), + "input_cost_per_token": (config.get("litellm_params") or {}).get( + "input_cost_per_token" + ) + if isinstance(config.get("litellm_params"), dict) + else None, + "output_cost_per_token": (config.get("litellm_params") or {}).get( + "output_cost_per_token" + ) + if isinstance(config.get("litellm_params"), dict) + else None, + "is_planner": config.get("is_planner", False), + "base_model": _base_model(config), + "router_pool_eligible": config.get("router_pool_eligible", True), + } + + +def materialize_global_model_catalog( + *, + chat_configs: list[dict[str, Any]], + vision_configs: list[dict[str, Any]], + image_configs: list[dict[str, Any]], +) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + connections: list[dict[str, Any]] = [] + models: list[dict[str, Any]] = [] + connection_id_by_key: dict[tuple[Any, ...], int] = {} + next_connection_id = -1 + + def add_config(config: dict[str, Any], role: str) -> None: + nonlocal next_connection_id + if not config.get("id") or not config.get("model_name"): + return + conn = native_connection_from_config(config) + conn["scope"] = "GLOBAL" + conn["enabled"] = True + conn["last_status"] = "OK" + key = _connection_key(conn) + connection_id = connection_id_by_key.get(key) + if connection_id is None: + connection_id = next_connection_id + next_connection_id -= 1 + connection_id_by_key[key] = connection_id + connections.append( + { + "id": connection_id, + **conn, + } + ) + + model_id = int(config["id"]) + models.append( + { + "id": model_id, + "connection_id": connection_id, + "model_id": config["model_name"], + "display_name": config.get("name") or config["model_name"], + "source": "MANUAL", + "capabilities": _capabilities_for(role, config), + "capabilities_declared": _capabilities_for(role, config), + "capabilities_verified": _capabilities_for(role, config), + "capabilities_override": {}, + "embedding_dimension": None, + "enabled": True, + "billing_tier": config.get("billing_tier", "free"), + "catalog": _catalog_metadata(config), + "role": role, + } + ) + + for cfg in chat_configs: + if cfg.get("is_auto_mode"): + continue + add_config(cfg, "chat") + for cfg in vision_configs: + if cfg.get("is_auto_mode"): + continue + add_config(cfg, "vision") + for cfg in image_configs: + if cfg.get("is_auto_mode"): + continue + add_config(cfg, "image_gen") + + # Each virtual connection is server-only. Callers that serialize these + # must strip api_key before returning data to clients. + return connections, models + + +__all__ = ["materialize_global_model_catalog"] diff --git a/surfsense_backend/app/services/model_resolver.py b/surfsense_backend/app/services/model_resolver.py new file mode 100644 index 000000000..ec485a5ae --- /dev/null +++ b/surfsense_backend/app/services/model_resolver.py @@ -0,0 +1,152 @@ +"""Single model-to-LiteLLM resolver. + +All chat, vision, image-generation, validation, and Auto routing paths should +turn a Connection + Model into LiteLLM input through this module. +""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any + +from app.services.provider_api_base import resolve_api_base + +if TYPE_CHECKING: + from app.db import Connection + +PROTOCOL_OLLAMA = "OLLAMA" +PROTOCOL_OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" +PROTOCOL_NATIVE = "NATIVE" + +NATIVE_PROVIDER_PREFIX: dict[str, str] = { + "OPENAI": "openai", + "ANTHROPIC": "anthropic", + "GROQ": "groq", + "COHERE": "cohere", + "GOOGLE": "gemini", + "MISTRAL": "mistral", + "AZURE_OPENAI": "azure", + "AZURE": "azure", + "OPENROUTER": "openrouter", + "COMETAPI": "cometapi", + "XAI": "xai", + "BEDROCK": "bedrock", + "AWS_BEDROCK": "bedrock", + "VERTEX_AI": "vertex_ai", + "TOGETHER_AI": "together_ai", + "FIREWORKS_AI": "fireworks_ai", + "DEEPSEEK": "openai", + "ALIBABA_QWEN": "openai", + "MOONSHOT": "openai", + "ZHIPU": "openai", + "GITHUB_MODELS": "github", + "REPLICATE": "replicate", + "PERPLEXITY": "perplexity", + "ANYSCALE": "anyscale", + "DEEPINFRA": "deepinfra", + "CEREBRAS": "cerebras", + "SAMBANOVA": "sambanova", + "AI21": "ai21", + "CLOUDFLARE": "cloudflare", + "DATABRICKS": "databricks", + "HUGGINGFACE": "huggingface", + "MINIMAX": "openai", + "RECRAFT": "recraft", + "XINFERENCE": "xinference", + "NSCALE": "nscale", + "CUSTOM": "custom", +} + + +def ensure_v1(base_url: str | None) -> str | None: + if not base_url: + return None + stripped = base_url.rstrip("/") + if stripped.endswith("/v1"): + return stripped + return f"{stripped}/v1" + + +def _conn_value(conn: Connection | Mapping[str, Any], key: str) -> Any: + if isinstance(conn, Mapping): + return conn.get(key) + return getattr(conn, key) + + +def _protocol_value(protocol: Any) -> str: + return getattr(protocol, "value", str(protocol)) + + +def to_litellm( + conn: Connection | Mapping[str, Any], + model_id: str, +) -> tuple[str, dict[str, Any]]: + """Return ``(model_string, litellm_kwargs)`` for any model role.""" + protocol = _protocol_value(_conn_value(conn, "protocol")) + base_url = _conn_value(conn, "base_url") + api_key = _conn_value(conn, "api_key") + native_provider = _conn_value(conn, "native_provider") + extra = _conn_value(conn, "extra") or {} + + kwargs: dict[str, Any] = {} + if api_key: + kwargs["api_key"] = api_key + + if protocol == PROTOCOL_OLLAMA: + model_string = f"ollama_chat/{model_id}" + if base_url: + kwargs["api_base"] = base_url.rstrip("/") + elif protocol == PROTOCOL_OPENAI_COMPATIBLE: + model_string = f"openai/{model_id}" + api_base = ensure_v1(base_url) + if api_base: + kwargs["api_base"] = api_base + else: + provider_key = (native_provider or "").upper() + prefix = NATIVE_PROVIDER_PREFIX.get(provider_key, provider_key.lower()) + if prefix == "custom": + custom_provider = extra.get("custom_provider") or native_provider + model_string = f"{custom_provider}/{model_id}" if custom_provider else model_id + else: + model_string = f"{prefix}/{model_id}" + + api_base = resolve_api_base( + provider=provider_key, + provider_prefix=prefix, + config_api_base=base_url, + ) + if api_base: + kwargs["api_base"] = api_base + + if api_version := extra.get("api_version"): + kwargs["api_version"] = api_version + kwargs.update(extra.get("litellm_params", {})) + kwargs.update(extra.get("kwargs", {})) + return model_string, kwargs + + +def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: + """Build an in-memory NATIVE connection mapping from a legacy/global config.""" + provider = str(config.get("provider") or config.get("custom_provider") or "CUSTOM") + extra: dict[str, Any] = { + "litellm_params": config.get("litellm_params") or {}, + } + if config.get("api_version"): + extra["api_version"] = config.get("api_version") + if config.get("custom_provider"): + extra["custom_provider"] = config.get("custom_provider") + return { + "protocol": PROTOCOL_NATIVE, + "native_provider": provider, + "base_url": config.get("api_base") or None, + "api_key": config.get("api_key") or None, + "extra": extra, + } + + +__all__ = [ + "NATIVE_PROVIDER_PREFIX", + "ensure_v1", + "native_connection_from_config", + "to_litellm", +] diff --git a/surfsense_backend/app/services/provider_capabilities.py b/surfsense_backend/app/services/provider_capabilities.py index f094c9954..9e1433214 100644 --- a/surfsense_backend/app/services/provider_capabilities.py +++ b/surfsense_backend/app/services/provider_capabilities.py @@ -46,6 +46,8 @@ from collections.abc import Iterable import litellm +from app.services.model_resolver import NATIVE_PROVIDER_PREFIX + logger = logging.getLogger(__name__) @@ -58,40 +60,7 @@ logger = logging.getLogger(__name__) # map there directly would re-introduce the # ``app.config -> ... -> deliverables/tools/generate_image -> # app.config`` cycle that prompted the move. -_PROVIDER_PREFIX_MAP: dict[str, str] = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "OLLAMA": "ollama_chat", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "OPENROUTER": "openrouter", - "XAI": "xai", - "BEDROCK": "bedrock", - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", - "GITHUB_MODELS": "github", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - "COMETAPI": "cometapi", - "HUGGINGFACE": "huggingface", - "MINIMAX": "openai", - "CUSTOM": "custom", -} +_PROVIDER_PREFIX_MAP = NATIVE_PROVIDER_PREFIX def _candidate_model_strings( From 62ff97c830f167c3810a7f50643f6cb779f98aad Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:48:23 +0530 Subject: [PATCH 03/59] refactor(llm): route calls through resolved models --- .../app/services/billable_calls.py | 4 +- .../app/services/image_gen_router_service.py | 53 +-- .../app/services/llm_router_service.py | 88 +--- surfsense_backend/app/services/llm_service.py | 434 +++++++----------- .../app/services/vision_llm_router_service.py | 53 +-- 5 files changed, 209 insertions(+), 423 deletions(-) diff --git a/surfsense_backend/app/services/billable_calls.py b/surfsense_backend/app/services/billable_calls.py index 92ccd6a78..356195f6a 100644 --- a/surfsense_backend/app/services/billable_calls.py +++ b/surfsense_backend/app/services/billable_calls.py @@ -450,10 +450,10 @@ async def _resolve_agent_billing_for_search_space( thread_id: int | None = None, ) -> tuple[UUID, str, str]: """Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space - agent LLM. + chat model. Used by Celery tasks (podcast generation, video presentation) to bill the - search-space owner's premium credit pool when the agent LLM is premium. + search-space owner's premium credit pool when the chat model is premium. Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``: diff --git a/surfsense_backend/app/services/image_gen_router_service.py b/surfsense_backend/app/services/image_gen_router_service.py index b4de2a0bf..0b03f5c6d 100644 --- a/surfsense_backend/app/services/image_gen_router_service.py +++ b/surfsense_backend/app/services/image_gen_router_service.py @@ -20,7 +20,11 @@ from typing import Any from litellm import Router from litellm.utils import ImageResponse -from app.services.provider_api_base import resolve_api_base +from app.services.model_resolver import ( + NATIVE_PROVIDER_PREFIX, + native_connection_from_config, + to_litellm, +) logger = logging.getLogger(__name__) @@ -30,17 +34,7 @@ IMAGE_GEN_AUTO_MODE_ID = 0 # Provider mapping for LiteLLM model string construction. # Only includes providers that support image generation. # See: https://docs.litellm.ai/docs/image_generation#supported-providers -IMAGE_GEN_PROVIDER_MAP = { - "OPENAI": "openai", - "AZURE_OPENAI": "azure", - "GOOGLE": "gemini", # Google AI Studio - "VERTEX_AI": "vertex_ai", - "BEDROCK": "bedrock", # AWS Bedrock - "RECRAFT": "recraft", - "OPENROUTER": "openrouter", - "XINFERENCE": "xinference", - "NSCALE": "nscale", -} +IMAGE_GEN_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX class ImageGenRouterService: @@ -153,38 +147,11 @@ class ImageGenRouterService: if not config.get("model_name") or not config.get("api_key"): return None - # Build model string - provider = config.get("provider", "").upper() - if config.get("custom_provider"): - provider_prefix = config["custom_provider"] - else: - provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower()) - model_string = f"{provider_prefix}/{config['model_name']}" - - # Build litellm params - litellm_params: dict[str, Any] = { - "model": model_string, - "api_key": config.get("api_key"), - } - - # Resolve ``api_base`` so deployments don't silently inherit - # ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against - # the wrong provider (see ``provider_api_base`` docstring). - api_base = resolve_api_base( - provider=provider, - provider_prefix=provider_prefix, - config_api_base=config.get("api_base"), + model_string, resolved_kwargs = to_litellm( + native_connection_from_config(config), + config["model_name"], ) - if api_base: - litellm_params["api_base"] = api_base - - # Add api_version (required for Azure) - if config.get("api_version"): - litellm_params["api_version"] = config["api_version"] - - # Add any additional litellm parameters - if config.get("litellm_params"): - litellm_params.update(config["litellm_params"]) + litellm_params: dict[str, Any] = {"model": model_string, **resolved_kwargs} # All configs use same alias "auto" for unified routing deployment: dict[str, Any] = { diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index d220aa346..69feb30eb 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -30,6 +30,11 @@ from litellm.exceptions import ( ) from pydantic import Field +from app.services.model_resolver import ( + NATIVE_PROVIDER_PREFIX, + native_connection_from_config, + to_litellm, +) from app.utils.perf import get_perf_logger litellm.json_logs = False @@ -96,52 +101,8 @@ def _sanitize_content(content: Any) -> Any: # Special ID for Auto mode - uses router for load balancing AUTO_MODE_ID = 0 -# Provider mapping for LiteLLM model string construction -PROVIDER_MAP = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "OLLAMA": "ollama_chat", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "OPENROUTER": "openrouter", - "COMETAPI": "cometapi", - "XAI": "xai", - "BEDROCK": "bedrock", - "AWS_BEDROCK": "bedrock", # Legacy support - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", - "GITHUB_MODELS": "github", - "HUGGINGFACE": "huggingface", - "MINIMAX": "openai", - "CUSTOM": "custom", -} - - -# ``PROVIDER_DEFAULT_API_BASE`` and ``PROVIDER_KEY_DEFAULT_API_BASE`` were -# hoisted to ``app.services.provider_api_base`` so vision and image-gen -# call sites can share the exact same defense (OpenRouter / Groq / etc. -# 404-ing against an inherited Azure endpoint). Re-exported here for -# backward compatibility with any external import. -from app.services.provider_api_base import ( # noqa: E402 - resolve_api_base, -) +# Historical export kept for callers that still import PROVIDER_MAP. +PROVIDER_MAP = NATIVE_PROVIDER_PREFIX class LLMRouterService: @@ -420,38 +381,11 @@ class LLMRouterService: if not config.get("model_name") or not config.get("api_key"): return None - # Build model string - provider = config.get("provider", "").upper() - if config.get("custom_provider"): - provider_prefix = config["custom_provider"] - model_string = f"{provider_prefix}/{config['model_name']}" - else: - provider_prefix = PROVIDER_MAP.get(provider, provider.lower()) - model_string = f"{provider_prefix}/{config['model_name']}" - - # Build litellm params - litellm_params = { - "model": model_string, - "api_key": config.get("api_key"), - } - - # Resolve ``api_base``. Config value wins; otherwise apply a - # provider-aware default so the deployment does not silently - # inherit unrelated env vars (e.g. ``AZURE_API_BASE``) and route - # requests to the wrong endpoint. See ``provider_api_base`` - # docstring for the motivating bug (OpenRouter models 404-ing - # against an Azure endpoint). - api_base = resolve_api_base( - provider=provider, - provider_prefix=provider_prefix, - config_api_base=config.get("api_base"), + model_string, resolved_kwargs = to_litellm( + native_connection_from_config(config), + config["model_name"], ) - if api_base: - litellm_params["api_base"] = api_base - - # Add any additional litellm parameters - if config.get("litellm_params"): - litellm_params.update(config["litellm_params"]) + litellm_params = {"model": model_string, **resolved_kwargs} # Extract rate limits if provided deployment = { diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 7061a826f..75451d01f 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -6,9 +6,10 @@ from langchain_core.messages import HumanMessage from langchain_litellm import ChatLiteLLM from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from sqlalchemy.orm import selectinload from app.config import config -from app.db import NewLLMConfig, SearchSpace +from app.db import Model, SearchSpace from app.services.llm_router_service import ( AUTO_MODE_ID, ChatLiteLLMRouter, @@ -16,7 +17,7 @@ from app.services.llm_router_service import ( get_auto_mode_llm, is_auto_mode, ) -from app.services.provider_api_base import resolve_api_base +from app.services.model_resolver import native_connection_from_config, to_litellm from app.services.token_tracking_service import token_tracker # Configure litellm to automatically drop unsupported parameters @@ -66,6 +67,29 @@ def _is_interactive_auth_provider( return False +def _legacy_config_connection( + *, + provider: str, + model_name: str, + api_key: str | None, + api_base: str | None, + custom_provider: str | None = None, + litellm_params: dict | None = None, + api_version: str | None = None, +) -> tuple[str, dict]: + cfg = { + "provider": provider, + "model_name": model_name, + "api_key": api_key, + "api_base": api_base, + "custom_provider": custom_provider, + "api_version": api_version, + "litellm_params": litellm_params or {}, + } + conn = native_connection_from_config(cfg) + return to_litellm(conn, model_name) + + class LLMRole: AGENT = "agent" # For agent/chat operations @@ -102,6 +126,60 @@ def get_global_llm_config(llm_config_id: int) -> dict | None: return None +def get_global_model(model_id: int) -> dict | None: + return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None) + + +def get_global_connection(connection_id: int) -> dict | None: + return next( + (c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id), + None, + ) + + +def _has_capability(model: dict | Model, capability: str) -> bool: + caps = ( + model.get("capabilities", {}) + if isinstance(model, dict) + else model.capabilities or {} + ) + return bool(caps.get(capability)) + + +def _chat_litellm_from_resolved( + *, + conn: dict | object, + model_id: str, + disable_streaming: bool = False, +) -> tuple[str, dict]: + model_string, resolved_kwargs = to_litellm(conn, model_id) + litellm_kwargs = {"model": model_string, **resolved_kwargs} + if disable_streaming: + litellm_kwargs["disable_streaming"] = True + return model_string, litellm_kwargs + + +async def _get_db_model( + session: AsyncSession, + model_id: int, + search_space: SearchSpace, +) -> Model | None: + result = await session.execute( + select(Model) + .options(selectinload(Model.connection)) + .where(Model.id == model_id, Model.enabled.is_(True)) + ) + model = result.scalars().first() + if not model or not model.connection or not model.connection.enabled: + return None + conn = model.connection + if conn.search_space_id and conn.search_space_id != search_space.id: + return None + if conn.user_id and conn.user_id != search_space.user_id: + return None + return model + + async def validate_llm_config( provider: str, model_name: str, @@ -146,62 +224,15 @@ async def validate_llm_config( return False, msg try: - # Build the model string for litellm - if custom_provider: - model_string = f"{custom_provider}/{model_name}" - else: - # Map provider enum to litellm format - provider_map = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "OLLAMA": "ollama_chat", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "OPENROUTER": "openrouter", - "COMETAPI": "cometapi", - "XAI": "xai", - "BEDROCK": "bedrock", - "AWS_BEDROCK": "bedrock", # Legacy support (backward compatibility) - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - # Chinese LLM providers - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", # GLM needs special handling - "MINIMAX": "openai", - "GITHUB_MODELS": "github", - } - provider_prefix = provider_map.get(provider, provider.lower()) - model_string = f"{provider_prefix}/{model_name}" - - # Create ChatLiteLLM instance - litellm_kwargs = { - "model": model_string, - "api_key": api_key, - "timeout": 30, # Set a timeout for validation - } - - # Add optional parameters - if api_base: - litellm_kwargs["api_base"] = api_base - - # Add any additional litellm parameters - if litellm_params: - litellm_kwargs.update(litellm_params) + model_string, resolved_kwargs = _legacy_config_connection( + provider=provider, + model_name=model_name, + api_key=api_key, + api_base=api_base, + custom_provider=custom_provider, + litellm_params=litellm_params, + ) + litellm_kwargs = {"model": model_string, **resolved_kwargs, "timeout": 30} from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -283,9 +314,9 @@ async def get_search_space_llm_instance( logger.error(f"Search space {search_space_id} not found") return None - # Get the appropriate LLM config ID based on role + # Get the appropriate model binding ID based on role if role == LLMRole.AGENT: - llm_config_id = search_space.agent_llm_id + llm_config_id = search_space.chat_model_id else: logger.error(f"Invalid LLM role: {role}") return None @@ -312,70 +343,26 @@ async def get_search_space_llm_instance( logger.error(f"Failed to create ChatLiteLLMRouter: {e}") return None - # Check if this is a global config (negative ID) + # Check if this is a global virtual model (negative ID) if llm_config_id < 0: - global_config = get_global_llm_config(llm_config_id) - if not global_config: - logger.error(f"Global LLM config {llm_config_id} not found") + global_model = get_global_model(llm_config_id) + if not global_model or not _has_capability(global_model, "chat"): + logger.error(f"Global chat model {llm_config_id} not found") + return None + global_connection = get_global_connection(global_model["connection_id"]) + if not global_connection: + logger.error( + "Global connection %s not found for model %s", + global_model["connection_id"], + llm_config_id, + ) return None - # Build model string for global config - if global_config.get("custom_provider"): - model_string = ( - f"{global_config['custom_provider']}/{global_config['model_name']}" - ) - else: - provider_map = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "OLLAMA": "ollama_chat", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "OPENROUTER": "openrouter", - "COMETAPI": "cometapi", - "XAI": "xai", - "BEDROCK": "bedrock", - "AWS_BEDROCK": "bedrock", - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", - "MINIMAX": "openai", - } - provider_prefix = provider_map.get( - global_config["provider"], global_config["provider"].lower() - ) - model_string = f"{provider_prefix}/{global_config['model_name']}" - - # Create ChatLiteLLM instance from global config - litellm_kwargs = { - "model": model_string, - "api_key": global_config["api_key"], - } - - if global_config.get("api_base"): - litellm_kwargs["api_base"] = global_config["api_base"] - - if global_config.get("litellm_params"): - litellm_kwargs.update(global_config["litellm_params"]) - - if disable_streaming: - litellm_kwargs["disable_streaming"] = True + _, litellm_kwargs = _chat_litellm_from_resolved( + conn=global_connection, + model_id=global_model["model_id"], + disable_streaming=disable_streaming, + ) from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -383,80 +370,18 @@ async def get_search_space_llm_instance( return SanitizedChatLiteLLM(**litellm_kwargs) - # Get the LLM configuration from database (NewLLMConfig) - result = await session.execute( - select(NewLLMConfig).where( - NewLLMConfig.id == llm_config_id, - NewLLMConfig.search_space_id == search_space_id, - ) - ) - llm_config = result.scalars().first() - - if not llm_config: + model = await _get_db_model(session, llm_config_id, search_space) + if not model or not _has_capability(model, "chat"): logger.error( - f"LLM config {llm_config_id} not found in search space {search_space_id}" + f"Chat model {llm_config_id} not found in search space {search_space_id}" ) return None - # Build the model string for litellm - if llm_config.custom_provider: - model_string = f"{llm_config.custom_provider}/{llm_config.model_name}" - else: - # Map provider enum to litellm format - provider_map = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "OLLAMA": "ollama_chat", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "OPENROUTER": "openrouter", - "COMETAPI": "cometapi", - "XAI": "xai", - "BEDROCK": "bedrock", - "AWS_BEDROCK": "bedrock", - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", - "MINIMAX": "openai", - "GITHUB_MODELS": "github", - } - provider_prefix = provider_map.get( - llm_config.provider.value, llm_config.provider.value.lower() - ) - model_string = f"{provider_prefix}/{llm_config.model_name}" - - # Create ChatLiteLLM instance - litellm_kwargs = { - "model": model_string, - "api_key": llm_config.api_key, - } - - # Add optional parameters - if llm_config.api_base: - litellm_kwargs["api_base"] = llm_config.api_base - - # Add any additional litellm parameters - if llm_config.litellm_params: - litellm_kwargs.update(llm_config.litellm_params) - - if disable_streaming: - litellm_kwargs["disable_streaming"] = True + _, litellm_kwargs = _chat_litellm_from_resolved( + conn=model.connection, + model_id=model.model_id, + disable_streaming=disable_streaming, + ) from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -474,7 +399,7 @@ async def get_search_space_llm_instance( async def get_agent_llm( session: AsyncSession, search_space_id: int, disable_streaming: bool = False ) -> ChatLiteLLM | ChatLiteLLMRouter | None: - """Get the search space's agent LLM instance for chat operations.""" + """Get the search space's chat model instance.""" return await get_search_space_llm_instance( session, search_space_id, @@ -488,22 +413,19 @@ async def get_vision_llm( ) -> ChatLiteLLM | ChatLiteLLMRouter | None: """Get the search space's vision LLM instance for screenshot analysis. - Resolves from the dedicated VisionLLMConfig system: + Resolves from the new connection/model role bindings: - Auto mode (ID 0): VisionLLMRouterService - - Global (negative ID): YAML configs - - DB (positive ID): VisionLLMConfig table + - Global (negative ID): virtual GLOBAL models from YAML + - DB (positive ID): Model + Connection tables Premium global configs are wrapped in :class:`QuotaCheckedVisionLLM` so each ``ainvoke`` debits the search-space owner's premium credit pool. User-owned BYOK configs and free global configs are returned unwrapped — they don't consume premium credit (issue M). """ - from app.db import VisionLLMConfig from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM from app.services.vision_llm_router_service import ( - VISION_PROVIDER_MAP, VisionLLMRouterService, - get_global_vision_llm_config, is_vision_auto_mode, ) @@ -516,13 +438,43 @@ async def get_vision_llm( logger.error(f"Search space {search_space_id} not found") return None - config_id = search_space.vision_llm_config_id + owner_user_id = search_space.user_id + + # Prefer the selected chat model when it is vision-capable. + chat_model_id = search_space.chat_model_id + if chat_model_id and chat_model_id != AUTO_MODE_ID: + if chat_model_id < 0: + chat_model = get_global_model(chat_model_id) + if chat_model and _has_capability(chat_model, "vision"): + global_connection = get_global_connection(chat_model["connection_id"]) + if global_connection: + model_string, litellm_kwargs = _chat_litellm_from_resolved( + conn=global_connection, + model_id=chat_model["model_id"], + ) + from app.agents.chat.runtime.llm_config import ( + SanitizedChatLiteLLM, + ) + + return SanitizedChatLiteLLM(**litellm_kwargs) + else: + chat_model = await _get_db_model(session, chat_model_id, search_space) + if chat_model and _has_capability(chat_model, "vision"): + _, litellm_kwargs = _chat_litellm_from_resolved( + conn=chat_model.connection, + model_id=chat_model.model_id, + ) + from app.agents.chat.runtime.llm_config import ( + SanitizedChatLiteLLM, + ) + + return SanitizedChatLiteLLM(**litellm_kwargs) + + config_id = search_space.vision_model_id if config_id is None: logger.error(f"No vision LLM configured for search space {search_space_id}") return None - owner_user_id = search_space.user_id - if is_vision_auto_mode(config_id): if not VisionLLMRouterService.is_initialized(): logger.error( @@ -546,34 +498,24 @@ async def get_vision_llm( return None if config_id < 0: - global_cfg = get_global_vision_llm_config(config_id) - if not global_cfg: - logger.error(f"Global vision LLM config {config_id} not found") + global_model = get_global_model(config_id) + if not global_model or not _has_capability(global_model, "vision"): + logger.error(f"Global vision model {config_id} not found") return None - if global_cfg.get("custom_provider"): - provider_prefix = global_cfg["custom_provider"] - model_string = f"{provider_prefix}/{global_cfg['model_name']}" - else: - provider_prefix = VISION_PROVIDER_MAP.get( - global_cfg["provider"].upper(), - global_cfg["provider"].lower(), + global_connection = get_global_connection(global_model["connection_id"]) + if not global_connection: + logger.error( + "Global connection %s not found for model %s", + global_model["connection_id"], + config_id, ) - model_string = f"{provider_prefix}/{global_cfg['model_name']}" + return None - litellm_kwargs = { - "model": model_string, - "api_key": global_cfg["api_key"], - } - api_base = resolve_api_base( - provider=global_cfg.get("provider"), - provider_prefix=provider_prefix, - config_api_base=global_cfg.get("api_base"), + model_string, litellm_kwargs = _chat_litellm_from_resolved( + conn=global_connection, + model_id=global_model["model_id"], ) - if api_base: - litellm_kwargs["api_base"] = api_base - if global_cfg.get("litellm_params"): - litellm_kwargs.update(global_cfg["litellm_params"]) from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, @@ -581,7 +523,7 @@ async def get_vision_llm( inner_llm = SanitizedChatLiteLLM(**litellm_kwargs) - billing_tier = str(global_cfg.get("billing_tier", "free")).lower() + billing_tier = str(global_model.get("billing_tier", "free")).lower() if billing_tier == "premium": return QuotaCheckedVisionLLM( inner_llm, @@ -589,47 +531,23 @@ async def get_vision_llm( search_space_id=search_space_id, billing_tier=billing_tier, base_model=model_string, - quota_reserve_tokens=global_cfg.get("quota_reserve_tokens"), + quota_reserve_tokens=global_model.get("catalog", {}).get( + "quota_reserve_tokens" + ), ) return inner_llm - # User-owned (positive ID) BYOK configs — always free. - result = await session.execute( - select(VisionLLMConfig).where( - VisionLLMConfig.id == config_id, - VisionLLMConfig.search_space_id == search_space_id, - ) - ) - vision_cfg = result.scalars().first() - if not vision_cfg: + model = await _get_db_model(session, config_id, search_space) + if not model or not _has_capability(model, "vision"): logger.error( - f"Vision LLM config {config_id} not found in search space {search_space_id}" + f"Vision model {config_id} not found in search space {search_space_id}" ) return None - if vision_cfg.custom_provider: - provider_prefix = vision_cfg.custom_provider - model_string = f"{provider_prefix}/{vision_cfg.model_name}" - else: - provider_prefix = VISION_PROVIDER_MAP.get( - vision_cfg.provider.value.upper(), - vision_cfg.provider.value.lower(), - ) - model_string = f"{provider_prefix}/{vision_cfg.model_name}" - - litellm_kwargs = { - "model": model_string, - "api_key": vision_cfg.api_key, - } - api_base = resolve_api_base( - provider=vision_cfg.provider.value, - provider_prefix=provider_prefix, - config_api_base=vision_cfg.api_base, + _, litellm_kwargs = _chat_litellm_from_resolved( + conn=model.connection, + model_id=model.model_id, ) - if api_base: - litellm_kwargs["api_base"] = api_base - if vision_cfg.litellm_params: - litellm_kwargs.update(vision_cfg.litellm_params) from app.agents.chat.runtime.llm_config import ( SanitizedChatLiteLLM, diff --git a/surfsense_backend/app/services/vision_llm_router_service.py b/surfsense_backend/app/services/vision_llm_router_service.py index ed5de921c..0c7182ecf 100644 --- a/surfsense_backend/app/services/vision_llm_router_service.py +++ b/surfsense_backend/app/services/vision_llm_router_service.py @@ -3,29 +3,17 @@ from typing import Any from litellm import Router -from app.services.provider_api_base import resolve_api_base +from app.services.model_resolver import ( + NATIVE_PROVIDER_PREFIX, + native_connection_from_config, + to_litellm, +) logger = logging.getLogger(__name__) VISION_AUTO_MODE_ID = 0 -VISION_PROVIDER_MAP = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GOOGLE": "gemini", - "AZURE_OPENAI": "azure", - "VERTEX_AI": "vertex_ai", - "BEDROCK": "bedrock", - "XAI": "xai", - "OPENROUTER": "openrouter", - "OLLAMA": "ollama_chat", - "GROQ": "groq", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "DEEPSEEK": "openai", - "MISTRAL": "mistral", - "CUSTOM": "custom", -} +VISION_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX class VisionLLMRouterService: @@ -110,32 +98,11 @@ class VisionLLMRouterService: if not config.get("model_name") or not config.get("api_key"): return None - provider = config.get("provider", "").upper() - if config.get("custom_provider"): - provider_prefix = config["custom_provider"] - model_string = f"{provider_prefix}/{config['model_name']}" - else: - provider_prefix = VISION_PROVIDER_MAP.get(provider, provider.lower()) - model_string = f"{provider_prefix}/{config['model_name']}" - - litellm_params: dict[str, Any] = { - "model": model_string, - "api_key": config.get("api_key"), - } - - api_base = resolve_api_base( - provider=provider, - provider_prefix=provider_prefix, - config_api_base=config.get("api_base"), + model_string, resolved_kwargs = to_litellm( + native_connection_from_config(config), + config["model_name"], ) - if api_base: - litellm_params["api_base"] = api_base - - if config.get("api_version"): - litellm_params["api_version"] = config["api_version"] - - if config.get("litellm_params"): - litellm_params.update(config["litellm_params"]) + litellm_params: dict[str, Any] = {"model": model_string, **resolved_kwargs} deployment: dict[str, Any] = { "model_name": "auto", From 077016d6e444c2bedb6d94cb4be6566771556815 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:48:37 +0530 Subject: [PATCH 04/59] refactor(images): use model connections for image generation --- .../deliverables/tools/generate_image.py | 104 +++++------------- .../builtins/deliverables/tools/index.py | 4 +- .../app/routes/image_generation_routes.py | 96 ++++------------ 3 files changed, 52 insertions(+), 152 deletions(-) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py index 7bb4a7c24..dd980c51c 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py @@ -10,13 +10,14 @@ from langgraph.types import Command from litellm import aimage_generation from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt from app.config import config from app.db import ( ImageGeneration, - ImageGenerationConfig, + Model, SearchSpace, shielded_async_session, ) @@ -25,37 +26,11 @@ from app.services.image_gen_router_service import ( ImageGenRouterService, is_image_gen_auto_mode, ) -from app.services.provider_api_base import resolve_api_base +from app.services.model_resolver import native_connection_from_config, to_litellm from app.utils.signed_image_urls import generate_image_token logger = logging.getLogger(__name__) -# Provider mapping (same as routes) -_PROVIDER_MAP = { - "OPENAI": "openai", - "AZURE_OPENAI": "azure", - "GOOGLE": "gemini", - "VERTEX_AI": "vertex_ai", - "BEDROCK": "bedrock", - "RECRAFT": "recraft", - "OPENROUTER": "openrouter", - "XINFERENCE": "xinference", - "NSCALE": "nscale", -} - - -def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: - if custom_provider: - return custom_provider - return _PROVIDER_MAP.get(provider.upper(), provider.lower()) - - -def _build_model_string( - provider: str, model_name: str, custom_provider: str | None -) -> str: - return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}" - - def _get_global_image_gen_config(config_id: int) -> dict | None: """Get a global image gen config by negative ID.""" for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: @@ -67,13 +42,13 @@ def _get_global_image_gen_config(config_id: int) -> dict | None: def create_generate_image_tool( search_space_id: int, db_session: AsyncSession, - image_generation_config_id_override: int | None = None, + image_gen_model_id_override: int | None = None, ): """Create ``generate_image`` with bound search space; DB work uses a per-call session. - ``image_generation_config_id_override``: when set (automations running on a - captured model), use this config id instead of reading the search space's - live ``image_generation_config_id``. + ``image_gen_model_id_override``: when set (automations running on a + captured model), use this model id instead of reading the search space's + live ``image_gen_model_id``. """ del db_session # tool uses a fresh per-call session instead @@ -118,11 +93,11 @@ def create_generate_image_tool( # task's session is shared across every tool; without isolation, # autoflushes from a concurrent writer poison this tool too. async with shielded_async_session() as session: - if image_generation_config_id_override is not None: + if image_gen_model_id_override is not None: # Automation run: use the captured image model, insulated from # later search-space changes. No search-space read needed. config_id = ( - image_generation_config_id_override or IMAGE_GEN_AUTO_MODE_ID + image_gen_model_id_override or IMAGE_GEN_AUTO_MODE_ID ) else: result = await session.execute( @@ -136,7 +111,7 @@ def create_generate_image_tool( ) config_id = ( - search_space.image_generation_config_id + search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID ) @@ -162,58 +137,35 @@ def create_generate_image_tool( err = f"Image generation config {config_id} not found" return _failed({"error": err}, error=err) - provider_prefix = _resolve_provider_prefix( - cfg.get("provider", ""), cfg.get("custom_provider") + model_string, resolved_kwargs = to_litellm( + native_connection_from_config(cfg), + cfg["model_name"], ) - model_string = f"{provider_prefix}/{cfg['model_name']}" - gen_kwargs["api_key"] = cfg.get("api_key") - # Defense-in-depth: an empty ``api_base`` must not fall - # through to LiteLLM's global ``api_base`` (e.g. Azure). - api_base = resolve_api_base( - provider=cfg.get("provider"), - provider_prefix=provider_prefix, - config_api_base=cfg.get("api_base"), - ) - if api_base: - gen_kwargs["api_base"] = api_base - if cfg.get("api_version"): - gen_kwargs["api_version"] = cfg["api_version"] - if cfg.get("litellm_params"): - gen_kwargs.update(cfg["litellm_params"]) + gen_kwargs.update(resolved_kwargs) response = await aimage_generation( prompt=prompt, model=model_string, **gen_kwargs ) else: - # Positive ID = user-created ImageGenerationConfig + # Positive ID = Model + Connection cfg_result = await session.execute( - select(ImageGenerationConfig).filter( - ImageGenerationConfig.id == config_id - ) + select(Model) + .options(selectinload(Model.connection)) + .filter(Model.id == config_id, Model.enabled.is_(True)) ) - db_cfg = cfg_result.scalars().first() - if not db_cfg: - err = f"Image generation config {config_id} not found" + db_model = cfg_result.scalars().first() + if not db_model or not db_model.connection or not db_model.connection.enabled: + err = f"Image generation model {config_id} not found" + return _failed({"error": err}, error=err) + if not (db_model.capabilities or {}).get("image_gen"): + err = f"Model {config_id} is not image-generation capable" return _failed({"error": err}, error=err) - provider_prefix = _resolve_provider_prefix( - db_cfg.provider.value, db_cfg.custom_provider + model_string, resolved_kwargs = to_litellm( + db_model.connection, + db_model.model_id, ) - model_string = f"{provider_prefix}/{db_cfg.model_name}" - gen_kwargs["api_key"] = db_cfg.api_key - # Defense-in-depth: an empty ``api_base`` must not fall - # through to LiteLLM's global ``api_base`` (e.g. Azure). - api_base = resolve_api_base( - provider=db_cfg.provider.value, - provider_prefix=provider_prefix, - config_api_base=db_cfg.api_base, - ) - if api_base: - gen_kwargs["api_base"] = api_base - if db_cfg.api_version: - gen_kwargs["api_version"] = db_cfg.api_version - if db_cfg.litellm_params: - gen_kwargs.update(db_cfg.litellm_params) + gen_kwargs.update(resolved_kwargs) response = await aimage_generation( prompt=prompt, model=model_string, **gen_kwargs diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py index b968c1701..8de95f2df 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py @@ -51,8 +51,6 @@ def load_tools( create_generate_image_tool( search_space_id=d["search_space_id"], db_session=d["db_session"], - image_generation_config_id_override=d.get( - "image_generation_config_id_override" - ), + image_gen_model_id_override=d.get("image_gen_model_id_override"), ), ] diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 018234ad5..0de368d57 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -16,11 +16,13 @@ from litellm import aimage_generation from sqlalchemy import select from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from app.config import config from app.db import ( ImageGeneration, ImageGenerationConfig, + Model, Permission, SearchSpace, SearchSpaceMembership, @@ -46,7 +48,7 @@ from app.services.image_gen_router_service import ( ImageGenRouterService, is_image_gen_auto_mode, ) -from app.services.provider_api_base import resolve_api_base +from app.services.model_resolver import native_connection_from_config, to_litellm from app.users import current_active_user from app.utils.rbac import check_permission from app.utils.signed_image_urls import verify_image_token @@ -54,22 +56,6 @@ from app.utils.signed_image_urls import verify_image_token router = APIRouter() logger = logging.getLogger(__name__) -# Provider mapping for building litellm model strings. -# Only includes providers that support image generation. -# See: https://docs.litellm.ai/docs/image_generation#supported-providers -_PROVIDER_MAP = { - "OPENAI": "openai", - "AZURE_OPENAI": "azure", - "GOOGLE": "gemini", # Google AI Studio - "VERTEX_AI": "vertex_ai", - "BEDROCK": "bedrock", # AWS Bedrock - "RECRAFT": "recraft", - "OPENROUTER": "openrouter", - "XINFERENCE": "xinference", - "NSCALE": "nscale", -} - - def _get_global_image_gen_config(config_id: int) -> dict | None: """Get a global image generation configuration by ID (negative IDs).""" if config_id == IMAGE_GEN_AUTO_MODE_ID: @@ -88,20 +74,6 @@ def _get_global_image_gen_config(config_id: int) -> dict | None: return None -def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: - """Resolve the LiteLLM provider prefix used in model strings.""" - if custom_provider: - return custom_provider - return _PROVIDER_MAP.get(provider.upper(), provider.lower()) - - -def _build_model_string( - provider: str, model_name: str, custom_provider: str | None -) -> str: - """Build a litellm model string from provider + model_name.""" - return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}" - - async def _resolve_billing_for_image_gen( session: AsyncSession, config_id: int | None, @@ -124,7 +96,7 @@ async def _resolve_billing_for_image_gen( """ resolved_id = config_id if resolved_id is None: - resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID + resolved_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID if is_image_gen_auto_mode(resolved_id): return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS) @@ -132,11 +104,7 @@ async def _resolve_billing_for_image_gen( if resolved_id < 0: cfg = _get_global_image_gen_config(resolved_id) or {} billing_tier = str(cfg.get("billing_tier", "free")).lower() - base_model = _build_model_string( - cfg.get("provider", ""), - cfg.get("model_name", ""), - cfg.get("custom_provider"), - ) + base_model, _ = to_litellm(native_connection_from_config(cfg), cfg.get("model_name", "")) reserve_micros = int( cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS ) @@ -161,7 +129,7 @@ async def _execute_image_generation( """ config_id = image_gen.image_generation_config_id if config_id is None: - config_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID + config_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID image_gen.image_generation_config_id = config_id # Build kwargs @@ -192,22 +160,11 @@ async def _execute_image_generation( if not cfg: raise ValueError(f"Global image generation config {config_id} not found") - provider_prefix = _resolve_provider_prefix( - cfg.get("provider", ""), cfg.get("custom_provider") + model_string, resolved_kwargs = to_litellm( + native_connection_from_config(cfg), + cfg["model_name"], ) - model_string = f"{provider_prefix}/{cfg['model_name']}" - gen_kwargs["api_key"] = cfg.get("api_key") - api_base = resolve_api_base( - provider=cfg.get("provider"), - provider_prefix=provider_prefix, - config_api_base=cfg.get("api_base"), - ) - if api_base: - gen_kwargs["api_base"] = api_base - if cfg.get("api_version"): - gen_kwargs["api_version"] = cfg["api_version"] - if cfg.get("litellm_params"): - gen_kwargs.update(cfg["litellm_params"]) + gen_kwargs.update(resolved_kwargs) # User model override if image_gen.model: @@ -217,30 +174,23 @@ async def _execute_image_generation( prompt=image_gen.prompt, model=model_string, **gen_kwargs ) else: - # Positive ID = DB ImageGenerationConfig + # Positive ID = Model + Connection result = await session.execute( - select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) + select(Model) + .options(selectinload(Model.connection)) + .filter(Model.id == config_id, Model.enabled.is_(True)) ) - db_cfg = result.scalars().first() - if not db_cfg: - raise ValueError(f"Image generation config {config_id} not found") + db_model = result.scalars().first() + if not db_model or not db_model.connection or not db_model.connection.enabled: + raise ValueError(f"Image generation model {config_id} not found") + if not (db_model.capabilities or {}).get("image_gen"): + raise ValueError(f"Model {config_id} is not image-generation capable") - provider_prefix = _resolve_provider_prefix( - db_cfg.provider.value, db_cfg.custom_provider + model_string, resolved_kwargs = to_litellm( + db_model.connection, + db_model.model_id, ) - model_string = f"{provider_prefix}/{db_cfg.model_name}" - gen_kwargs["api_key"] = db_cfg.api_key - api_base = resolve_api_base( - provider=db_cfg.provider.value, - provider_prefix=provider_prefix, - config_api_base=db_cfg.api_base, - ) - if api_base: - gen_kwargs["api_base"] = api_base - if db_cfg.api_version: - gen_kwargs["api_version"] = db_cfg.api_version - if db_cfg.litellm_params: - gen_kwargs.update(db_cfg.litellm_params) + gen_kwargs.update(resolved_kwargs) # User model override if image_gen.model: From 18606fe3880cd4c53ba912620bdd91ad6e0bccd8 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:48:53 +0530 Subject: [PATCH 05/59] feat(automations): add model connection policy support --- .../builtin/agent_task/dependencies.py | 30 +++---- .../actions/builtin/agent_task/invoke.py | 8 +- .../app/automations/actions/types.py | 6 +- .../app/automations/runtime/executor.py | 8 +- .../schemas/definition/envelope.py | 10 +-- .../app/automations/services/automation.py | 12 +-- .../app/automations/services/model_policy.py | 87 +++++++------------ 7 files changed, 67 insertions(+), 94 deletions(-) diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py b/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py index 4ef8c52bf..c9584ae2a 100644 --- a/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/dependencies.py @@ -39,31 +39,31 @@ async def build_dependencies( *, session: AsyncSession, search_space_id: int, - agent_llm_id: int | None = None, - image_generation_config_id: int | None = None, - vision_llm_config_id: int | None = None, + chat_model_id: int | None = None, + image_gen_model_id: int | None = None, + vision_model_id: int | None = None, ) -> AgentDependencies: """Load the LLM bundle, connector service, and a per-invoke in-memory checkpointer. - Resolves the agent LLM from the automation's *captured* model snapshot - (``agent_llm_id``) so runs are insulated from later chat/search-space model + Resolves the chat model from the automation's *captured* model snapshot + (``chat_model_id``) so runs are insulated from later chat/search-space model changes. The model policy is enforced here as a runtime backstop: a captured model that is no longer billable (e.g. a premium global config was removed) fails the run clearly instead of silently consuming a free model. - When ``agent_llm_id`` is ``None`` (no captured snapshot — defensive fallback), - fall back to the live search space's ``agent_llm_id`` and validate that. + When ``chat_model_id`` is ``None`` (no captured snapshot — defensive fallback), + fall back to the live search space's ``chat_model_id`` and validate that. """ - if agent_llm_id is not None: + if chat_model_id is not None: try: assert_models_billable( - agent_llm_id=agent_llm_id, - image_generation_config_id=image_generation_config_id, - vision_llm_config_id=vision_llm_config_id, + chat_model_id=chat_model_id, + image_gen_model_id=image_gen_model_id, + vision_model_id=vision_model_id, ) except AutomationModelPolicyError as exc: raise DependencyError(str(exc)) from exc - resolved_agent_llm_id = agent_llm_id or 0 + resolved_chat_model_id = chat_model_id or 0 else: search_space = await session.get(SearchSpace, search_space_id) if search_space is None: @@ -72,15 +72,15 @@ async def build_dependencies( assert_automation_models_billable(search_space) except AutomationModelPolicyError as exc: raise DependencyError(str(exc)) from exc - resolved_agent_llm_id = search_space.agent_llm_id or 0 + resolved_chat_model_id = search_space.chat_model_id or 0 llm, agent_config, err = await load_llm_bundle( session, - config_id=resolved_agent_llm_id, + config_id=resolved_chat_model_id, search_space_id=search_space_id, ) if err is not None or llm is None: - raise DependencyError(err or "failed to load agent LLM config") + raise DependencyError(err or "failed to load chat model config") connector_service, firecrawl_api_key = await setup_connector_and_firecrawl( session, search_space_id=search_space_id diff --git a/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py b/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py index aa96e4f6e..c3a35930d 100644 --- a/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py +++ b/surfsense_backend/app/automations/actions/builtin/agent_task/invoke.py @@ -150,9 +150,9 @@ async def run_agent_task( deps = await build_dependencies( session=agent_session, search_space_id=ctx.search_space_id, - agent_llm_id=ctx.agent_llm_id, - image_generation_config_id=ctx.image_generation_config_id, - vision_llm_config_id=ctx.vision_llm_config_id, + chat_model_id=ctx.chat_model_id, + image_gen_model_id=ctx.image_gen_model_id, + vision_model_id=ctx.vision_model_id, ) agent = await create_multi_agent_chat_deep_agent( @@ -167,7 +167,7 @@ async def run_agent_task( firecrawl_api_key=deps.firecrawl_api_key, thread_visibility=ChatVisibility.PRIVATE, mentioned_document_ids=mentioned_document_ids, - image_generation_config_id=ctx.image_generation_config_id, + image_gen_model_id=ctx.image_gen_model_id, ) agent_query, runtime_context = await _resolve_mention_context( diff --git a/surfsense_backend/app/automations/actions/types.py b/surfsense_backend/app/automations/actions/types.py index 453721a43..3ee427512 100644 --- a/surfsense_backend/app/automations/actions/types.py +++ b/surfsense_backend/app/automations/actions/types.py @@ -23,9 +23,9 @@ class ActionContext: # Captured model snapshot from the automation definition (``definition.models``), # resolved per run instead of the live search space. ``None`` falls back to the # search space's current prefs (defensive; should not happen post-capture). - agent_llm_id: int | None = None - image_generation_config_id: int | None = None - vision_llm_config_id: int | None = None + chat_model_id: int | None = None + image_gen_model_id: int | None = None + vision_model_id: int | None = None ActionHandler = Callable[[dict[str, Any]], Awaitable[Any]] diff --git a/surfsense_backend/app/automations/runtime/executor.py b/surfsense_backend/app/automations/runtime/executor.py index da249d8e5..bcdab3940 100644 --- a/surfsense_backend/app/automations/runtime/executor.py +++ b/surfsense_backend/app/automations/runtime/executor.py @@ -132,9 +132,7 @@ def _build_action_ctx( step_id=step.step_id, search_space_id=automation.search_space_id, creator_user_id=automation.created_by_user_id, - agent_llm_id=models.agent_llm_id if models else None, - image_generation_config_id=( - models.image_generation_config_id if models else None - ), - vision_llm_config_id=models.vision_llm_config_id if models else None, + chat_model_id=models.chat_model_id if models else None, + image_gen_model_id=models.image_gen_model_id if models else None, + vision_model_id=models.vision_model_id if models else None, ) diff --git a/surfsense_backend/app/automations/schemas/definition/envelope.py b/surfsense_backend/app/automations/schemas/definition/envelope.py index 7ca55b1ce..787534d4a 100644 --- a/surfsense_backend/app/automations/schemas/definition/envelope.py +++ b/surfsense_backend/app/automations/schemas/definition/envelope.py @@ -14,16 +14,16 @@ from .trigger_spec import TriggerSpec class AutomationModels(BaseModel): """Captured model profile for an automation. - Snapshotted from the search space's preferences at create time so runs are - insulated from later chat/search-space model changes. Config-id conventions + Snapshotted from the search space's model roles at create time so runs are + insulated from later chat/search-space model changes. Model-id conventions match the shared scheme (``0`` Auto, ``< 0`` global, ``> 0`` BYOK). """ model_config = ConfigDict(extra="forbid") - agent_llm_id: int = 0 - image_generation_config_id: int = 0 - vision_llm_config_id: int = 0 + chat_model_id: int = 0 + image_gen_model_id: int = 0 + vision_model_id: int = 0 class AutomationDefinition(BaseModel): diff --git a/surfsense_backend/app/automations/services/automation.py b/surfsense_backend/app/automations/services/automation.py index 4227161e2..1d371c35d 100644 --- a/surfsense_backend/app/automations/services/automation.py +++ b/surfsense_backend/app/automations/services/automation.py @@ -57,9 +57,9 @@ class AutomationService: else: search_space = await self._assert_models_billable(payload.search_space_id) payload.definition.models = AutomationModels( - agent_llm_id=search_space.agent_llm_id or 0, - image_generation_config_id=search_space.image_generation_config_id or 0, - vision_llm_config_id=search_space.vision_llm_config_id or 0, + chat_model_id=search_space.chat_model_id or 0, + image_gen_model_id=search_space.image_gen_model_id or 0, + vision_model_id=search_space.vision_model_id or 0, ) automation = Automation( @@ -225,9 +225,9 @@ class AutomationService: """ try: assert_models_billable( - agent_llm_id=models.agent_llm_id, - image_generation_config_id=models.image_generation_config_id, - vision_llm_config_id=models.vision_llm_config_id, + chat_model_id=models.chat_model_id, + image_gen_model_id=models.image_gen_model_id, + vision_model_id=models.vision_model_id, ) except AutomationModelPolicyError as exc: raise HTTPException(status_code=422, detail=str(exc)) from exc diff --git a/surfsense_backend/app/automations/services/model_policy.py b/surfsense_backend/app/automations/services/model_policy.py index 7e3e46b61..b160fc78d 100644 --- a/surfsense_backend/app/automations/services/model_policy.py +++ b/surfsense_backend/app/automations/services/model_policy.py @@ -24,70 +24,45 @@ from typing import TYPE_CHECKING, Literal if TYPE_CHECKING: from app.db import SearchSpace -ModelKind = Literal["llm", "image", "vision"] +ModelKind = Literal["chat", "image", "vision"] _KIND_LABEL: dict[ModelKind, str] = { - "llm": "agent LLM", + "chat": "chat model", "image": "image generation model", "vision": "vision model", } -def _is_premium_global(kind: ModelKind, config_id: int) -> bool: - """Return True if a negative (global) config id is a premium tier model.""" +def _is_premium_global(model_id: int) -> bool: + """Return True if a negative (global) model id is a premium tier model.""" from app.config import config as app_config - cfg: dict | None = None - if kind == "llm": - from app.agents.chat.runtime.llm_config import ( - load_global_llm_config_by_id, - ) - - cfg = load_global_llm_config_by_id(config_id) - elif kind == "image": - cfg = next( - ( - c - for c in app_config.GLOBAL_IMAGE_GEN_CONFIGS - if c.get("id") == config_id - ), - None, - ) - else: # vision - cfg = next( - ( - c - for c in app_config.GLOBAL_VISION_LLM_CONFIGS - if c.get("id") == config_id - ), - None, - ) - - if not cfg: + model = next((m for m in app_config.GLOBAL_MODELS if m.get("id") == model_id), None) + if not model: return False - return str(cfg.get("billing_tier", "free")).lower() == "premium" + return str(model.get("billing_tier", "free")).lower() == "premium" -def _classify(kind: ModelKind, config_id: int | None) -> tuple[bool, str]: - """Classify a resolved config id as allowed or blocked. +def _classify(kind: ModelKind, model_id: int | None) -> tuple[bool, str]: + """Classify a resolved model id as allowed or blocked. Returns ``(allowed, reason)``; ``reason`` is empty when allowed. """ label = _KIND_LABEL[kind] - if config_id is None or config_id == 0: + if model_id is None or model_id == 0: return ( False, f"The {label} is set to Auto mode. Automations require an explicit " "premium model or your own (BYOK) model so every run is billable.", ) - if config_id > 0: - # Positive id → user-owned BYOK config. Always allowed. + if model_id > 0: + # Positive id -> user/search-space BYOK model. Always allowed. return True, "" - # Negative id → global config. Allowed only if premium. - if _is_premium_global(kind, config_id): + # Negative id -> global model. Allowed only if premium. + if _is_premium_global(model_id): return True, "" return ( @@ -99,27 +74,27 @@ def _classify(kind: ModelKind, config_id: int | None) -> tuple[bool, str]: def get_model_eligibility( *, - agent_llm_id: int | None, - image_generation_config_id: int | None, - vision_llm_config_id: int | None, + chat_model_id: int | None, + image_gen_model_id: int | None, + vision_model_id: int | None, ) -> dict: - """Return ``{"allowed": bool, "violations": [...]}`` for explicit config ids. + """Return ``{"allowed": bool, "violations": [...]}`` for explicit model ids. The ID-based core shared by both the search-space path (creation/eligibility) and the captured-snapshot path (runtime backstop). Each violation is ``{"kind", "config_id", "reason"}``. """ checks: list[tuple[ModelKind, int | None]] = [ - ("llm", agent_llm_id), - ("image", image_generation_config_id), - ("vision", vision_llm_config_id), + ("chat", chat_model_id), + ("image", image_gen_model_id), + ("vision", vision_model_id), ] violations: list[dict] = [] for kind, config_id in checks: allowed, reason = _classify(kind, config_id) if not allowed: - violations.append({"kind": kind, "config_id": config_id, "reason": reason}) + violations.append({"kind": kind, "model_id": config_id, "reason": reason}) return {"allowed": not violations, "violations": violations} @@ -131,9 +106,9 @@ def get_automation_model_eligibility(search_space: SearchSpace) -> dict: wrapper over :func:`get_model_eligibility`. """ return get_model_eligibility( - agent_llm_id=search_space.agent_llm_id, - image_generation_config_id=search_space.image_generation_config_id, - vision_llm_config_id=search_space.vision_llm_config_id, + chat_model_id=search_space.chat_model_id, + image_gen_model_id=search_space.image_gen_model_id, + vision_model_id=search_space.vision_model_id, ) @@ -150,9 +125,9 @@ class AutomationModelPolicyError(Exception): def assert_models_billable( *, - agent_llm_id: int | None, - image_generation_config_id: int | None, - vision_llm_config_id: int | None, + chat_model_id: int | None, + image_gen_model_id: int | None, + vision_model_id: int | None, ) -> None: """Raise :class:`AutomationModelPolicyError` if any explicit id is not billable. @@ -160,9 +135,9 @@ def assert_models_billable( captured model snapshot. """ result = get_model_eligibility( - agent_llm_id=agent_llm_id, - image_generation_config_id=image_generation_config_id, - vision_llm_config_id=vision_llm_config_id, + chat_model_id=chat_model_id, + image_gen_model_id=image_gen_model_id, + vision_model_id=vision_model_id, ) if not result["allowed"]: raise AutomationModelPolicyError(result["violations"]) From 32ab2b8713e6b67493ed6873e3e74f11e6013240 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:49:07 +0530 Subject: [PATCH 06/59] feat(web): expose model policies in automations --- .../builder/automation-builder-form.tsx | 2 +- .../builder/automation-model-fields.tsx | 14 +-- .../automations/automations-mutation.atoms.ts | 6 +- .../tool-ui/automation/create-automation.tsx | 18 +-- .../contracts/types/automation.types.ts | 6 +- .../hooks/use-automation-eligible-models.ts | 117 +++++++----------- .../lib/automations/builder-schema.ts | 20 +-- 7 files changed, 77 insertions(+), 106 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx index 59967080f..a68e53a1c 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-builder-form.tsx @@ -130,7 +130,7 @@ export function AutomationBuilderForm({ // data into state, so there's no flicker/loop and the user's pick is sticky. const resolvedModels = useMemo( () => ({ - agentLlmId: form.models.agentLlmId || eligibleModels.llm.defaultId || 0, + chatModelId: form.models.chatModelId || eligibleModels.llm.defaultId || 0, imageConfigId: form.models.imageConfigId || eligibleModels.image.defaultId || 0, visionConfigId: form.models.visionConfigId || eligibleModels.vision.defaultId || 0, }), diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-model-fields.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-model-fields.tsx index 2c4a0bf60..6dd42366b 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-model-fields.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/builder/automation-model-fields.tsx @@ -25,7 +25,7 @@ import { getProviderIcon } from "@/lib/provider-icons"; import { Field } from "./form-field"; export interface AutomationModelSelection { - agentLlmId: number; + chatModelId: number; imageConfigId: number; visionConfigId: number; } @@ -39,7 +39,7 @@ interface AutomationModelFieldsProps { } /** - * Three eligible-only model pickers (Agent LLM / Image / Vision) for the + * Three eligible-only model pickers (Chat / Image / Vision) for the * automation builder + chat approval card. Options come from * {@link useAutomationEligibleModels} (premium globals + BYOK only); selection * is validated + snapshotted onto `definition.models` at create time. @@ -51,18 +51,18 @@ export function AutomationModelFields({ errors, }: AutomationModelFieldsProps) { const { llm, image, vision, isLoading } = useAutomationEligibleModels(); - const rolesHref = `/dashboard/${searchSpaceId}/search-space-settings/roles`; + const rolesHref = `/dashboard/${searchSpaceId}/search-space-settings/models`; return (
onChange({ agentLlmId: id })} + error={errors?.chatModelId} + onChange={(id) => onChange({ chatModelId: id })} /> ({ task_count: variables.definition.plan.length, trigger_type: variables.triggers?.[0]?.type ?? "none", has_schedule: (variables.triggers?.length ?? 0) > 0, - agent_llm_id: variables.definition.models?.agent_llm_id, - image_generation_config_id: variables.definition.models?.image_generation_config_id, - vision_llm_config_id: variables.definition.models?.vision_llm_config_id, + chat_model_id: variables.definition.models?.chat_model_id, + image_gen_model_id: variables.definition.models?.image_gen_model_id, + vision_model_id: variables.definition.models?.vision_model_id, tags_count: variables.definition.metadata?.tags?.length, }); }, diff --git a/surfsense_web/components/tool-ui/automation/create-automation.tsx b/surfsense_web/components/tool-ui/automation/create-automation.tsx index 24e9d66bd..2a7d09f53 100644 --- a/surfsense_web/components/tool-ui/automation/create-automation.tsx +++ b/surfsense_web/components/tool-ui/automation/create-automation.tsx @@ -113,7 +113,7 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) { const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom); const eligibleModels = useAutomationEligibleModels(); const [modelSelection, setModelSelection] = useState({ - agentLlmId: 0, + chatModelId: 0, imageConfigId: 0, visionConfigId: 0, }); @@ -121,7 +121,7 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) { // default. No effect seeds async hook data into state. const resolvedModels = useMemo( () => ({ - agentLlmId: modelSelection.agentLlmId || eligibleModels.llm.defaultId || 0, + chatModelId: modelSelection.chatModelId || eligibleModels.llm.defaultId || 0, imageConfigId: modelSelection.imageConfigId || eligibleModels.image.defaultId || 0, visionConfigId: modelSelection.visionConfigId || eligibleModels.vision.defaultId || 0, }), @@ -133,7 +133,7 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) { ] ); const modelsResolved = - resolvedModels.agentLlmId !== 0 && + resolvedModels.chatModelId !== 0 && resolvedModels.imageConfigId !== 0 && resolvedModels.visionConfigId !== 0; @@ -147,9 +147,9 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) { definition: { ...baseDefinition, models: { - agent_llm_id: resolvedModels.agentLlmId, - image_generation_config_id: resolvedModels.imageConfigId, - vision_llm_config_id: resolvedModels.visionConfigId, + chat_model_id: resolvedModels.chatModelId, + image_gen_model_id: resolvedModels.imageConfigId, + vision_model_id: resolvedModels.visionConfigId, }, }, }; @@ -162,9 +162,9 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) { trigger_type: (triggers[0] as { type?: string } | undefined)?.type ?? (triggers.length ? undefined : "none"), - agent_llm_id: resolvedModels.agentLlmId, - image_generation_config_id: resolvedModels.imageConfigId, - vision_llm_config_id: resolvedModels.visionConfigId, + chat_model_id: resolvedModels.chatModelId, + image_gen_model_id: resolvedModels.imageConfigId, + vision_model_id: resolvedModels.visionConfigId, }); onDecision({ type: "edit", diff --git a/surfsense_web/contracts/types/automation.types.ts b/surfsense_web/contracts/types/automation.types.ts index 45670d245..6331a663c 100644 --- a/surfsense_web/contracts/types/automation.types.ts +++ b/surfsense_web/contracts/types/automation.types.ts @@ -63,9 +63,9 @@ export type Inputs = z.infer; // Captured model snapshot (server-managed). Set at create time and preserved // across edits so runs are insulated from later chat/search-space model changes. export const automationModels = z.object({ - agent_llm_id: z.number().int().default(0), - image_generation_config_id: z.number().int().default(0), - vision_llm_config_id: z.number().int().default(0), + chat_model_id: z.number().int().default(0), + image_gen_model_id: z.number().int().default(0), + vision_model_id: z.number().int().default(0), }); export type AutomationModels = z.infer; diff --git a/surfsense_web/hooks/use-automation-eligible-models.ts b/surfsense_web/hooks/use-automation-eligible-models.ts index e74994221..e75235c56 100644 --- a/surfsense_web/hooks/use-automation-eligible-models.ts +++ b/surfsense_web/hooks/use-automation-eligible-models.ts @@ -3,18 +3,11 @@ import { useAtomValue } from "jotai"; import { useMemo } from "react"; import { - globalImageGenConfigsAtom, - imageGenConfigsAtom, -} from "@/atoms/image-gen-config/image-gen-config-query.atoms"; -import { - globalNewLLMConfigsAtom, - llmPreferencesAtom, - newLLMConfigsAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; -import { - globalVisionLLMConfigsAtom, - visionLLMConfigsAtom, -} from "@/atoms/vision-llm-config/vision-llm-config-query.atoms"; + globalModelConnectionsAtom, + modelConnectionsAtom, + modelRolesAtom, +} from "@/atoms/model-connections/model-connections-query.atoms"; +import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; /** * A single model the user may pick for an automation slot. @@ -44,48 +37,40 @@ export interface AutomationEligibleModels { isLoading: boolean; } -interface GlobalConfigLike { - id: number; - name: string; - model_name: string; - provider: string; - is_premium?: boolean; - is_auto_mode?: boolean; -} - -interface UserConfigLike { - id: number; - name: string; - model_name: string; - provider: string; -} - /** * Build the eligible option list for one model kind: premium globals - * (`is_premium === true`, never Auto mode) followed by all BYOK configs. + * followed by all BYOK/search-space models. */ function buildKind( - globals: GlobalConfigLike[] | undefined, - byok: UserConfigLike[] | undefined, + globals: ConnectionRead[] | undefined, + byok: ConnectionRead[] | undefined, + capability: "chat" | "image_gen" | "vision", prefId: number | null | undefined ): EligibleModelKind { - const premiumGlobals: EligibleModelOption[] = (globals ?? []) - .filter((c) => c.is_premium === true && !c.is_auto_mode) - .map((c) => ({ - id: c.id, - name: c.name, - modelName: c.model_name, - provider: c.provider, - isBYOK: false, - })); + const toOption = (connection: ConnectionRead, model: ModelRead, isBYOK: boolean) => ({ + id: model.id, + name: model.display_name || model.model_id, + modelName: model.model_id, + provider: connection.native_provider || connection.protocol, + isBYOK, + }); - const byokOptions: EligibleModelOption[] = (byok ?? []).map((c) => ({ - id: c.id, - name: c.name, - modelName: c.model_name, - provider: c.provider, - isBYOK: true, - })); + const premiumGlobals: EligibleModelOption[] = (globals ?? []).flatMap((connection) => + connection.models + .filter( + (model) => + model.enabled && + Boolean(model.capabilities?.[capability]) && + String(model.billing_tier ?? "").toLowerCase() === "premium" + ) + .map((model) => toOption(connection, model, false)) + ); + + const byokOptions: EligibleModelOption[] = (byok ?? []).flatMap((connection) => + connection.models + .filter((model) => model.enabled && Boolean(model.capabilities?.[capability])) + .map((model) => toOption(connection, model, true)) + ); const options = [...premiumGlobals, ...byokOptions]; const byId = new Map(options.map((o) => [o.id, o])); @@ -105,46 +90,32 @@ function buildKind( * (premium globals + user BYOK — never free globals or Auto mode), with a * default selection seeded from the search space's role preferences. * - * Everything is derived during render from the existing config query atoms; + * Everything is derived during render from the connection/model query atoms; * there are no effects, so option lists/maps keep stable references. */ export function useAutomationEligibleModels(): AutomationEligibleModels { - const { data: llmUserConfigs, isLoading: llmUserLoading } = useAtomValue(newLLMConfigsAtom); - const { data: llmGlobalConfigs, isLoading: llmGlobalLoading } = - useAtomValue(globalNewLLMConfigsAtom); - const { data: preferences, isLoading: prefsLoading } = useAtomValue(llmPreferencesAtom); - const { data: imageGlobalConfigs, isLoading: imageGlobalLoading } = - useAtomValue(globalImageGenConfigsAtom); - const { data: imageUserConfigs, isLoading: imageUserLoading } = useAtomValue(imageGenConfigsAtom); - const { data: visionGlobalConfigs, isLoading: visionGlobalLoading } = useAtomValue( - globalVisionLLMConfigsAtom + const { data: byokConnections, isLoading: byokLoading } = useAtomValue(modelConnectionsAtom); + const { data: globalConnections, isLoading: globalLoading } = useAtomValue( + globalModelConnectionsAtom ); - const { data: visionUserConfigs, isLoading: visionUserLoading } = - useAtomValue(visionLLMConfigsAtom); + const { data: roles, isLoading: rolesLoading } = useAtomValue(modelRolesAtom); const llm = useMemo( - () => buildKind(llmGlobalConfigs, llmUserConfigs, preferences?.agent_llm_id), - [llmGlobalConfigs, llmUserConfigs, preferences?.agent_llm_id] + () => buildKind(globalConnections, byokConnections, "chat", roles?.chat_model_id), + [globalConnections, byokConnections, roles?.chat_model_id] ); const image = useMemo( - () => buildKind(imageGlobalConfigs, imageUserConfigs, preferences?.image_generation_config_id), - [imageGlobalConfigs, imageUserConfigs, preferences?.image_generation_config_id] + () => buildKind(globalConnections, byokConnections, "image_gen", roles?.image_gen_model_id), + [globalConnections, byokConnections, roles?.image_gen_model_id] ); const vision = useMemo( - () => buildKind(visionGlobalConfigs, visionUserConfigs, preferences?.vision_llm_config_id), - [visionGlobalConfigs, visionUserConfigs, preferences?.vision_llm_config_id] + () => buildKind(globalConnections, byokConnections, "vision", roles?.vision_model_id), + [globalConnections, byokConnections, roles?.vision_model_id] ); - const isLoading = - llmUserLoading || - llmGlobalLoading || - prefsLoading || - imageGlobalLoading || - imageUserLoading || - visionGlobalLoading || - visionUserLoading; + const isLoading = byokLoading || globalLoading || rolesLoading; return useMemo(() => ({ llm, image, vision, isLoading }), [llm, image, vision, isLoading]); } diff --git a/surfsense_web/lib/automations/builder-schema.ts b/surfsense_web/lib/automations/builder-schema.ts index c2bd69209..5bb034bef 100644 --- a/surfsense_web/lib/automations/builder-schema.ts +++ b/surfsense_web/lib/automations/builder-schema.ts @@ -73,7 +73,7 @@ export type BuilderExecution = z.infer; * later chat/search-space model changes. */ export const builderModelsSchema = z.object({ - agentLlmId: z.number().int(), + chatModelId: z.number().int(), imageConfigId: z.number().int(), visionConfigId: z.number().int(), }); @@ -90,7 +90,7 @@ export const builderFormSchema = z.object({ tags: z.array(z.string()), /** Carried through from an edited definition so we don't drop it. */ goal: z.string().nullable(), - /** Selected agent/image/vision models (``0`` = use the eligible default). */ + /** Selected chat/image/vision models (``0`` = use the eligible default). */ models: builderModelsSchema, }); export type BuilderForm = z.infer; @@ -147,7 +147,7 @@ export function createEmptyForm(): BuilderForm { }, tags: [], goal: null, - models: { agentLlmId: 0, imageConfigId: 0, visionConfigId: 0 }, + models: { chatModelId: 0, imageConfigId: 0, visionConfigId: 0 }, }; } @@ -240,9 +240,9 @@ function buildDefinition(form: BuilderForm): AutomationDefinition { ...(hasResolvedModels(form.models) ? { models: { - agent_llm_id: form.models.agentLlmId, - image_generation_config_id: form.models.imageConfigId, - vision_llm_config_id: form.models.visionConfigId, + chat_model_id: form.models.chatModelId, + image_gen_model_id: form.models.imageConfigId, + vision_model_id: form.models.visionConfigId, }, } : {}), @@ -251,7 +251,7 @@ function buildDefinition(form: BuilderForm): AutomationDefinition { /** True once every model slot holds a concrete (non-zero) id. */ export function hasResolvedModels(models: BuilderModels): boolean { - return models.agentLlmId !== 0 && models.imageConfigId !== 0 && models.visionConfigId !== 0; + return models.chatModelId !== 0 && models.imageConfigId !== 0 && models.visionConfigId !== 0; } /** The desired schedule trigger for this form, or ``null`` if none. */ @@ -500,9 +500,9 @@ function modelsFromDefinition(raw: unknown): BuilderModels { const m = asRecord(raw); const num = (value: unknown) => (typeof value === "number" ? value : 0); return { - agentLlmId: num(m.agent_llm_id), - imageConfigId: num(m.image_generation_config_id), - visionConfigId: num(m.vision_llm_config_id), + chatModelId: num(m.chat_model_id), + imageConfigId: num(m.image_gen_model_id), + visionConfigId: num(m.vision_model_id), }; } From 0674accc23c3474758b39203752ce8bb5c5d1f1d Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:49:21 +0530 Subject: [PATCH 07/59] feat(web): add model connection client data layer --- .../model-connections-mutation.atoms.ts | 129 ++++++++++++++++++ .../model-connections-query.atoms.ts | 32 +++++ .../types/model-connections.types.ts | 98 +++++++++++++ .../lib/apis/model-connections-api.service.ts | 88 ++++++++++++ surfsense_web/lib/query-client/cache-keys.ts | 5 + 5 files changed, 352 insertions(+) create mode 100644 surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts create mode 100644 surfsense_web/atoms/model-connections/model-connections-query.atoms.ts create mode 100644 surfsense_web/contracts/types/model-connections.types.ts create mode 100644 surfsense_web/lib/apis/model-connections-api.service.ts diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts new file mode 100644 index 000000000..7d58a402c --- /dev/null +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -0,0 +1,129 @@ +import { atomWithMutation } from "jotai-tanstack-query"; +import { toast } from "sonner"; +import type { + ConnectionCreateRequest, + ConnectionUpdateRequest, + ModelRoles, + ModelUpdateRequest, +} from "@/contracts/types/model-connections.types"; +import { modelConnectionsApiService } from "@/lib/apis/model-connections-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { queryClient } from "@/lib/query-client/client"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +function invalidateModelConnections(searchSpaceId: number) { + queryClient.invalidateQueries({ + queryKey: cacheKeys.modelConnections.all(searchSpaceId), + }); + queryClient.invalidateQueries({ + queryKey: cacheKeys.modelConnections.roles(searchSpaceId), + }); +} + +export const createModelConnectionMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["model-connections", "create"], + mutationFn: (request: ConnectionCreateRequest) => + modelConnectionsApiService.createConnection(request), + onSuccess: () => { + toast.success("Connection created"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to create connection"), + }; +}); + +export const updateModelConnectionMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["model-connections", "update"], + mutationFn: ({ id, data }: { id: number; data: ConnectionUpdateRequest }) => + modelConnectionsApiService.updateConnection(id, data), + onSuccess: () => { + toast.success("Connection updated"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to update connection"), + }; +}); + +export const deleteModelConnectionMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["model-connections", "delete"], + mutationFn: (id: number) => modelConnectionsApiService.deleteConnection(id), + onSuccess: () => { + toast.success("Connection deleted"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to delete connection"), + }; +}); + +export const verifyModelConnectionMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["model-connections", "verify"], + mutationFn: (id: number) => modelConnectionsApiService.verifyConnection(id), + onSuccess: (result) => { + if (result.ok) toast.success("Connection verified"); + else toast.error(result.message || "Connection failed"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to verify connection"), + }; +}); + +export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["model-connections", "discover"], + mutationFn: (id: number) => modelConnectionsApiService.discoverModels(id), + onSuccess: () => { + toast.success("Models discovered"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to discover models"), + }; +}); + +export const updateModelMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["models", "update"], + mutationFn: ({ id, data }: { id: number; data: ModelUpdateRequest }) => + modelConnectionsApiService.updateModel(id, data), + onSuccess: () => invalidateModelConnections(searchSpaceId), + onError: (error: Error) => toast.error(error.message || "Failed to update model"), + }; +}); + +export const testModelMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["models", "test"], + mutationFn: (id: number) => modelConnectionsApiService.testModel(id), + onSuccess: (result) => { + if (result.ok) toast.success("Model test succeeded"); + else toast.error(result.message || "Model test failed"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to test model"), + }; +}); + +export const updateModelRolesMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["model-roles", "update"], + mutationFn: (roles: ModelRoles) => + modelConnectionsApiService.updateModelRoles(searchSpaceId, roles), + onSuccess: () => { + queryClient.invalidateQueries({ + queryKey: cacheKeys.modelConnections.roles(searchSpaceId), + }); + }, + onError: (error: Error) => toast.error(error.message || "Failed to update model roles"), + }; +}); diff --git a/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts new file mode 100644 index 000000000..617ffe124 --- /dev/null +++ b/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts @@ -0,0 +1,32 @@ +import { atomWithQuery } from "jotai-tanstack-query"; +import { modelConnectionsApiService } from "@/lib/apis/model-connections-api.service"; +import { getBearerToken } from "@/lib/auth-utils"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +export const globalModelConnectionsAtom = atomWithQuery(() => ({ + queryKey: cacheKeys.modelConnections.global(), + enabled: !!getBearerToken(), + staleTime: 10 * 60 * 1000, + queryFn: () => modelConnectionsApiService.getGlobalConnections(), +})); + +export const modelConnectionsAtom = atomWithQuery((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + queryKey: cacheKeys.modelConnections.all(searchSpaceId), + enabled: !!searchSpaceId, + staleTime: 5 * 60 * 1000, + queryFn: () => modelConnectionsApiService.getConnections(searchSpaceId), + }; +}); + +export const modelRolesAtom = atomWithQuery((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + queryKey: cacheKeys.modelConnections.roles(searchSpaceId), + enabled: !!searchSpaceId, + staleTime: 5 * 60 * 1000, + queryFn: () => modelConnectionsApiService.getModelRoles(searchSpaceId), + }; +}); diff --git a/surfsense_web/contracts/types/model-connections.types.ts b/surfsense_web/contracts/types/model-connections.types.ts new file mode 100644 index 000000000..14f93c61a --- /dev/null +++ b/surfsense_web/contracts/types/model-connections.types.ts @@ -0,0 +1,98 @@ +import { z } from "zod"; + +export const connectionProtocolEnum = z.enum(["OLLAMA", "OPENAI_COMPATIBLE", "NATIVE"]); +export const connectionScopeEnum = z.enum(["GLOBAL", "SEARCH_SPACE", "USER"]); +export const modelSourceEnum = z.enum(["DISCOVERED", "MANUAL"]); + +export const modelCapabilities = z.object({ + chat: z.boolean().optional(), + vision: z.boolean().optional(), + image_gen: z.boolean().optional(), + embedding: z.boolean().optional(), + tools: z.boolean().optional(), +}); + +export const modelRead = z.object({ + id: z.number(), + connection_id: z.number(), + model_id: z.string(), + display_name: z.string().nullable().optional(), + source: z.union([modelSourceEnum, z.string()]), + capabilities: z.record(z.string(), z.any()).default({}), + capabilities_declared: z.record(z.string(), z.any()).default({}), + capabilities_verified: z.record(z.string(), z.any()).default({}), + capabilities_override: z.record(z.string(), z.any()).default({}), + embedding_dimension: z.number().nullable().optional(), + enabled: z.boolean(), + billing_tier: z.string().nullable().optional(), + catalog: z.record(z.string(), z.any()).default({}), + created_at: z.string().nullable().optional(), +}); + +export const connectionRead = z.object({ + id: z.number(), + protocol: z.union([connectionProtocolEnum, z.string()]), + native_provider: z.string().nullable().optional(), + base_url: z.string().nullable().optional(), + extra: z.record(z.string(), z.any()).default({}), + scope: z.union([connectionScopeEnum, z.string()]), + search_space_id: z.number().nullable().optional(), + user_id: z.string().nullable().optional(), + enabled: z.boolean(), + has_api_key: z.boolean(), + last_verified_at: z.string().nullable().optional(), + last_status: z.string().nullable().optional(), + last_error: z.string().nullable().optional(), + models: z.array(modelRead).default([]), + created_at: z.string().nullable().optional(), +}); + +export const connectionCreateRequest = z.object({ + protocol: connectionProtocolEnum, + native_provider: z.string().nullable().optional(), + base_url: z.string().nullable().optional(), + api_key: z.string().nullable().optional(), + extra: z.record(z.string(), z.any()).default({}), + scope: connectionScopeEnum.default("SEARCH_SPACE"), + search_space_id: z.number().nullable().optional(), + enabled: z.boolean().default(true), +}); + +export const connectionUpdateRequest = z.object({ + native_provider: z.string().nullable().optional(), + base_url: z.string().nullable().optional(), + api_key: z.string().nullable().optional(), + extra: z.record(z.string(), z.any()).optional(), + enabled: z.boolean().optional(), +}); + +export const modelUpdateRequest = z.object({ + display_name: z.string().nullable().optional(), + enabled: z.boolean().optional(), + capabilities_override: z.record(z.string(), z.any()).optional(), +}); + +export const verifyConnectionResponse = z.object({ + status: z.string(), + ok: z.boolean(), + message: z.string().default(""), +}); + +export const modelRoles = z.object({ + chat_model_id: z.number().nullable().optional(), + vision_model_id: z.number().nullable().optional(), + image_gen_model_id: z.number().nullable().optional(), +}); + +export const connectionListResponse = z.array(connectionRead); +export const modelListResponse = z.array(modelRead); + +export type ConnectionProtocol = z.infer; +export type ConnectionScope = z.infer; +export type ModelRead = z.infer; +export type ConnectionRead = z.infer; +export type ConnectionCreateRequest = z.infer; +export type ConnectionUpdateRequest = z.infer; +export type ModelUpdateRequest = z.infer; +export type ModelRoles = z.infer; +export type VerifyConnectionResponse = z.infer; diff --git a/surfsense_web/lib/apis/model-connections-api.service.ts b/surfsense_web/lib/apis/model-connections-api.service.ts new file mode 100644 index 000000000..ca92ad11b --- /dev/null +++ b/surfsense_web/lib/apis/model-connections-api.service.ts @@ -0,0 +1,88 @@ +import { + type ConnectionCreateRequest, + type ConnectionUpdateRequest, + connectionCreateRequest, + connectionListResponse, + connectionRead, + connectionUpdateRequest, + type ModelRoles, + type ModelUpdateRequest, + modelListResponse, + modelRead, + modelRoles, + modelUpdateRequest, + verifyConnectionResponse, +} from "@/contracts/types/model-connections.types"; +import { ValidationError } from "../error"; +import { baseApiService } from "./base-api.service"; + +class ModelConnectionsApiService { + getGlobalConnections = async () => { + return baseApiService.get(`/api/v1/global-model-connections`, connectionListResponse); + }; + + getConnections = async (searchSpaceId: number) => { + return baseApiService.get( + `/api/v1/model-connections?search_space_id=${searchSpaceId}`, + connectionListResponse + ); + }; + + createConnection = async (request: ConnectionCreateRequest) => { + const parsed = connectionCreateRequest.safeParse(request); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); + } + return baseApiService.post(`/api/v1/model-connections`, connectionRead, { + body: parsed.data, + }); + }; + + updateConnection = async (id: number, request: ConnectionUpdateRequest) => { + const parsed = connectionUpdateRequest.safeParse(request); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); + } + return baseApiService.put(`/api/v1/model-connections/${id}`, connectionRead, { + body: parsed.data, + }); + }; + + deleteConnection = async (id: number) => { + return baseApiService.delete(`/api/v1/model-connections/${id}`, undefined); + }; + + verifyConnection = async (id: number) => { + return baseApiService.post(`/api/v1/model-connections/${id}/verify`, verifyConnectionResponse); + }; + + discoverModels = async (id: number) => { + return baseApiService.post(`/api/v1/model-connections/${id}/discover`, modelListResponse); + }; + + updateModel = async (id: number, request: ModelUpdateRequest) => { + const parsed = modelUpdateRequest.safeParse(request); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); + } + return baseApiService.put(`/api/v1/models/${id}`, modelRead, { + body: parsed.data, + }); + }; + + testModel = async (id: number) => { + return baseApiService.post(`/api/v1/models/${id}/test`, verifyConnectionResponse); + }; + + getModelRoles = async (searchSpaceId: number) => { + return baseApiService.get(`/api/v1/search-spaces/${searchSpaceId}/model-roles`, modelRoles); + }; + + updateModelRoles = async (searchSpaceId: number, roles: ModelRoles) => { + return baseApiService.put(`/api/v1/search-spaces/${searchSpaceId}/model-roles`, modelRoles, { + body: roles, + }); + }; +} + +export const modelConnectionsApiService = new ModelConnectionsApiService(); diff --git a/surfsense_web/lib/query-client/cache-keys.ts b/surfsense_web/lib/query-client/cache-keys.ts index 6f8885d7e..558a73f95 100644 --- a/surfsense_web/lib/query-client/cache-keys.ts +++ b/surfsense_web/lib/query-client/cache-keys.ts @@ -44,6 +44,11 @@ export const cacheKeys = { global: () => ["new-llm-configs", "global"] as const, modelList: () => ["models", "catalogue"] as const, }, + modelConnections: { + all: (searchSpaceId: number) => ["model-connections", searchSpaceId] as const, + global: () => ["model-connections", "global"] as const, + roles: (searchSpaceId: number) => ["model-roles", searchSpaceId] as const, + }, imageGenConfigs: { all: (searchSpaceId: number) => ["image-gen-configs", searchSpaceId] as const, byId: (configId: number) => ["image-gen-configs", "detail", configId] as const, From 4bda0ffa9691c3069798d58a095dac58d9dd2042 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:49:33 +0530 Subject: [PATCH 08/59] feat(settings): add model connection management UI --- .../search-space-settings/layout-shell.tsx | 23 +- .../search-space-settings/models/page.tsx | 4 +- .../components/settings/llm-role-manager.tsx | 4 +- .../settings/model-connections-settings.tsx | 384 ++++++++++++++++++ 4 files changed, 389 insertions(+), 26 deletions(-) create mode 100644 surfsense_web/components/settings/model-connections-settings.tsx diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx index 22f68edab..9d9045004 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx @@ -5,9 +5,6 @@ import { Bot, CircleUser, Earth, - ImageIcon, - ListChecks, - ScanEye, UserKey, } from "lucide-react"; import Link from "next/link"; @@ -20,10 +17,7 @@ import { cn } from "@/lib/utils"; export type SearchSpaceSettingsTab = | "general" - | "roles" | "models" - | "image-models" - | "vision-models" | "team-roles" | "prompts" | "public-links"; @@ -57,26 +51,11 @@ export function SearchSpaceSettingsLayoutShell({ label: t("nav_general"), icon: , }, - { - value: "roles" as const, - label: t("nav_role_assignments"), - icon: , - }, { value: "models" as const, - label: t("nav_agent_models"), + label: t("nav_models"), icon: , }, - { - value: "image-models" as const, - label: t("nav_image_models"), - icon: , - }, - { - value: "vision-models" as const, - label: t("nav_vision_models"), - icon: , - }, { value: "team-roles" as const, label: t("nav_team_roles"), diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/models/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/models/page.tsx index d68194782..c97ef7630 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/models/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/models/page.tsx @@ -1,6 +1,6 @@ -import { AgentModelManager } from "@/components/settings/agent-model-manager"; +import { ModelConnectionsSettings } from "@/components/settings/model-connections-settings"; export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) { const { search_space_id } = await params; - return ; + return ; } diff --git a/surfsense_web/components/settings/llm-role-manager.tsx b/surfsense_web/components/settings/llm-role-manager.tsx index c32e79a8e..547675927 100644 --- a/surfsense_web/components/settings/llm-role-manager.tsx +++ b/surfsense_web/components/settings/llm-role-manager.tsx @@ -48,8 +48,8 @@ import { cn } from "@/lib/utils"; const ROLE_DESCRIPTIONS = { agent: { icon: Bot, - title: "Agent LLM", - description: "Primary LLM for chat interactions and agent operations", + title: "Chat model", + description: "Primary model for chat interactions and agent operations", color: "text-muted-foreground", bgColor: "bg-muted", prefKey: "agent_llm_id" as const, diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx new file mode 100644 index 000000000..5fa4cccf7 --- /dev/null +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -0,0 +1,384 @@ +"use client"; + +import { useAtom, useAtomValue } from "jotai"; +import { CheckCircle2, PlugZap, RefreshCcw, XCircle } from "lucide-react"; +import { useMemo, useState } from "react"; +import { + createModelConnectionMutationAtom, + discoverConnectionModelsMutationAtom, + testModelMutationAtom, + updateModelMutationAtom, + updateModelRolesMutationAtom, + verifyModelConnectionMutationAtom, +} from "@/atoms/model-connections/model-connections-mutation.atoms"; +import { + globalModelConnectionsAtom, + modelConnectionsAtom, + modelRolesAtom, +} from "@/atoms/model-connections/model-connections-query.atoms"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import type { + ConnectionProtocol, + ConnectionRead, + ModelRead, +} from "@/contracts/types/model-connections.types"; +import { isCloud } from "@/lib/env-config"; +import { getProviderIcon } from "@/lib/provider-icons"; + +type Preset = { + id: string; + label: string; + protocol: ConnectionProtocol; + nativeProvider?: string; + baseUrl?: string; + local?: boolean; +}; + +const PRESETS: Preset[] = [ + { id: "openai", label: "OpenAI", protocol: "NATIVE", nativeProvider: "OPENAI" }, + { id: "anthropic", label: "Anthropic", protocol: "NATIVE", nativeProvider: "ANTHROPIC" }, + { id: "openrouter", label: "OpenRouter", protocol: "NATIVE", nativeProvider: "OPENROUTER" }, + { + id: "ollama", + label: "Ollama", + protocol: "OLLAMA", + baseUrl: "http://host.docker.internal:11434", + local: true, + }, + { + id: "lmstudio", + label: "LM Studio", + protocol: "OPENAI_COMPATIBLE", + baseUrl: "http://host.docker.internal:1234/v1", + local: true, + }, + { + id: "llamacpp", + label: "llama.cpp", + protocol: "OPENAI_COMPATIBLE", + baseUrl: "http://host.docker.internal:8080/v1", + local: true, + }, + { + id: "localai", + label: "LocalAI", + protocol: "OPENAI_COMPATIBLE", + baseUrl: "http://host.docker.internal:8080/v1", + local: true, + }, + { + id: "vllm", + label: "vLLM", + protocol: "OPENAI_COMPATIBLE", + baseUrl: "http://host.docker.internal:8000/v1", + local: true, + }, +]; + +function modelLabel(model: ModelRead) { + return model.display_name || model.model_id; +} + +function capability(model: ModelRead, key: "chat" | "vision" | "image_gen") { + return Boolean(model.capabilities?.[key]); +} + +function StatusBadge({ connection }: { connection: ConnectionRead }) { + if (connection.last_status === "OK") { + return ( + + Healthy + + ); + } + if (connection.last_status) { + return ( + + {connection.last_status} + + ); + } + return Not tested; +} + +function flattenModels(connections: ConnectionRead[]) { + return connections.flatMap((connection) => + connection.models.map((model) => ({ + ...model, + connectionName: connection.native_provider || connection.protocol, + connectionId: connection.id, + provider: connection.native_provider || connection.protocol, + })) + ); +} + +export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: number }) { + const [{ data: globalConnections = [] }] = useAtom(globalModelConnectionsAtom); + const [{ data: connections = [] }] = useAtom(modelConnectionsAtom); + const [{ data: roles }] = useAtom(modelRolesAtom); + const createConnection = useAtomValue(createModelConnectionMutationAtom); + const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom); + const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); + const updateModel = useAtomValue(updateModelMutationAtom); + const testModel = useAtomValue(testModelMutationAtom); + const updateRoles = useAtomValue(updateModelRolesMutationAtom); + + const visiblePresets = useMemo( + () => PRESETS.filter((preset) => !(isCloud() && preset.local)), + [] + ); + const [presetId, setPresetId] = useState(visiblePresets[0]?.id ?? "openai"); + const preset = visiblePresets.find((item) => item.id === presetId) ?? visiblePresets[0]; + const [baseUrl, setBaseUrl] = useState(preset?.baseUrl ?? ""); + const [apiKey, setApiKey] = useState(""); + + const allConnections = [...globalConnections, ...connections]; + const enabledModels = flattenModels(allConnections).filter((model) => model.enabled); + const chatModels = enabledModels.filter((model) => capability(model, "chat")); + const visionModels = enabledModels.filter((model) => capability(model, "vision")); + const imageModels = enabledModels.filter((model) => capability(model, "image_gen")); + + function onPresetChange(value: string) { + setPresetId(value); + const next = visiblePresets.find((item) => item.id === value); + setBaseUrl(next?.baseUrl ?? ""); + } + + function handleCreate() { + if (!preset) return; + createConnection.mutate({ + protocol: preset.protocol, + native_provider: preset.nativeProvider, + base_url: baseUrl || null, + api_key: apiKey || null, + scope: "SEARCH_SPACE", + search_space_id: searchSpaceId, + extra: {}, + enabled: true, + }); + } + + function renderModelOption(model: ModelRead & { connectionName: string; provider: string }) { + return ( + + + {getProviderIcon(model.provider, { className: "size-4" })} + {modelLabel(model)} · {model.connectionName} + + + ); + } + + return ( +
+ + + Model Connections + + Add credentials or local endpoints once, then discover reusable models. + + + +
+
+ + +
+
+ + setBaseUrl(event.target.value)} + placeholder="https://api.example.com/v1" + /> +
+
+ + setApiKey(event.target.value)} + placeholder="Optional for local models" + type="password" + /> +
+
+ +
+
+ {preset?.local ? ( +

+ Local URLs are tested from the backend container. Use host.docker.internal instead of + localhost. +

+ ) : null} + +
+ {connections.map((connection) => ( +
+
+
+
+ {getProviderIcon(connection.native_provider || connection.protocol, { + className: "size-4", + })} + {connection.native_provider || connection.protocol} +
+
+ {connection.base_url || "Provider default endpoint"} +
+
+
+ + + +
+
+ {connection.last_error ? ( +

{connection.last_error}

+ ) : null} +
+ {connection.models.map((model) => ( +
+
+
+ {getProviderIcon(connection.native_provider || connection.protocol, { + className: "size-4", + })} + {modelLabel(model)} +
+
+ {["chat", "vision", "image_gen"] + .filter((key) => Boolean(model.capabilities?.[key])) + .join(", ") || "No verified capabilities"} +
+
+
+ + +
+
+ ))} +
+
+ ))} +
+
+
+ + + + Model Roles + + Pick which enabled model powers chat, vision, and image generation for this search + space. + + + +
+ + +
+
+ + +
+
+ + +
+
+
+
+ ); +} From 39cca36c31bd17483b365951b361c0ff4e70c4ec Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:49:55 +0530 Subject: [PATCH 09/59] feat(onboarding): include model connections in setup flow --- .../[search_space_id]/client-layout.tsx | 56 ++--- .../[search_space_id]/onboard/page.tsx | 216 +++++------------- surfsense_web/lib/posthog/events.ts | 12 +- 3 files changed, 96 insertions(+), 188 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx index 3a41b5998..c7e05fe99 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx @@ -4,15 +4,15 @@ import { useAtomValue, useSetAtom } from "jotai"; import { useParams, usePathname, useRouter } from "next/navigation"; import { useTranslations } from "next-intl"; import type React from "react"; -import { useCallback, useEffect, useRef, useState } from "react"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom"; import { myAccessAtom } from "@/atoms/members/members-query.atoms"; -import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; +import { updateModelRolesMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { - globalNewLLMConfigsAtom, - llmPreferencesAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; + globalModelConnectionsAtom, + modelRolesAtom, +} from "@/atoms/model-connections/model-connections-query.atoms"; import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; import { DocumentUploadDialogProvider } from "@/components/assistant-ui/document-upload-popup"; import { LayoutDataProvider } from "@/components/layout"; @@ -21,7 +21,6 @@ import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/com import { useFolderSync } from "@/hooks/use-folder-sync"; import { useGlobalLoadingEffect } from "@/hooks/use-global-loading"; import { useElectronAPI } from "@/hooks/use-platform"; -import { isLlmOnboardingComplete } from "@/lib/onboarding"; export function DashboardClientLayout({ children, @@ -38,18 +37,27 @@ export function DashboardClientLayout({ const setPendingUserImageUrls = useSetAtom(pendingUserImageDataUrlsAtom); const { - data: preferences = {}, + data: modelRoles = {}, isFetching: loading, error, - refetch: refetchPreferences, - } = useAtomValue(llmPreferencesAtom); - const { data: globalConfigs = [], isFetching: globalConfigsLoading } = - useAtomValue(globalNewLLMConfigsAtom); - const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); + refetch: refetchModelRoles, + } = useAtomValue(modelRolesAtom); + const { data: globalConnections = [], isFetching: globalConfigsLoading } = useAtomValue( + globalModelConnectionsAtom + ); + const { mutateAsync: updateModelRoles } = useAtomValue(updateModelRolesMutationAtom); + + const firstGlobalChatModel = useMemo(() => { + for (const connection of globalConnections) { + const model = connection.models.find((item) => item.enabled && item.capabilities?.chat); + if (model) return model; + } + return null; + }, [globalConnections]); const isOnboardingComplete = useCallback(() => { - return isLlmOnboardingComplete(preferences.agent_llm_id, globalConfigs.length > 0); - }, [preferences.agent_llm_id, globalConfigs.length]); + return (modelRoles.chat_model_id ?? 0) !== 0 || Boolean(firstGlobalChatModel); + }, [modelRoles.chat_model_id, firstGlobalChatModel]); const { data: access = null, isLoading: accessLoading } = useAtomValue(myAccessAtom); const [hasCheckedOnboarding, setHasCheckedOnboarding] = useState(false); @@ -84,24 +92,18 @@ export function DashboardClientLayout({ return; } - if (globalConfigs.length > 0 && !hasAttemptedAutoConfig.current) { + if (firstGlobalChatModel && !hasAttemptedAutoConfig.current) { hasAttemptedAutoConfig.current = true; setIsAutoConfiguring(true); const autoConfigureWithGlobal = async () => { try { - const firstGlobalConfig = globalConfigs[0]; - await updatePreferences({ - search_space_id: Number(searchSpaceId), - data: { - agent_llm_id: firstGlobalConfig.id, - }, - }); + await updateModelRoles({ chat_model_id: firstGlobalChatModel.id }); - await refetchPreferences(); + await refetchModelRoles(); toast.success("AI configured automatically!", { - description: `Using ${firstGlobalConfig.name}. Customize in Settings.`, + description: `Using ${firstGlobalChatModel.display_name || firstGlobalChatModel.model_id}. Customize in Settings.`, }); setHasCheckedOnboarding(true); @@ -128,12 +130,12 @@ export function DashboardClientLayout({ isOnboardingPage, isOwner, isAutoConfiguring, - globalConfigs, + firstGlobalChatModel, router, searchSpaceId, hasCheckedOnboarding, - updatePreferences, - refetchPreferences, + updateModelRoles, + refetchModelRoles, ]); const electronAPI = useElectronAPI(); diff --git a/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx index de5c961e8..9cf429a3a 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx @@ -2,193 +2,99 @@ import { useAtomValue } from "jotai"; import { useParams, useRouter } from "next/navigation"; -import { useEffect, useRef, useState } from "react"; +import { useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; +import { updateModelRolesMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { - createNewLLMConfigMutationAtom, - updateLLMPreferencesMutationAtom, -} from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; -import { - globalNewLLMConfigsAtom, - llmPreferencesAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; + globalModelConnectionsAtom, + modelRolesAtom, +} from "@/atoms/model-connections/model-connections-query.atoms"; import { Logo } from "@/components/Logo"; -import { LLMConfigForm, type LLMConfigFormData } from "@/components/shared/llm-config-form"; import { Button } from "@/components/ui/button"; import { Spinner } from "@/components/ui/spinner"; import { useGlobalLoadingEffect } from "@/hooks/use-global-loading"; import { getBearerToken, redirectToLogin } from "@/lib/auth-utils"; -import { isLlmOnboardingComplete } from "@/lib/onboarding"; export default function OnboardPage() { const router = useRouter(); const params = useParams(); const searchSpaceId = Number(params.search_space_id); - // Queries - const { - data: globalConfigs = [], - isFetching: globalConfigsLoading, - isSuccess: globalConfigsLoaded, - } = useAtomValue(globalNewLLMConfigsAtom); - const { data: preferences = {}, isFetching: preferencesLoading } = - useAtomValue(llmPreferencesAtom); - - // Mutations - const { mutateAsync: createConfig, isPending: isCreating } = useAtomValue( - createNewLLMConfigMutationAtom + const { data: globalConnections = [], isFetching: globalLoading } = useAtomValue( + globalModelConnectionsAtom ); - const { mutateAsync: updatePreferences, isPending: isUpdatingPreferences } = useAtomValue( - updateLLMPreferencesMutationAtom - ); - - // State + const { data: roles = {}, isFetching: rolesLoading } = useAtomValue(modelRolesAtom); + const { mutateAsync: updateRoles, isPending } = useAtomValue(updateModelRolesMutationAtom); const [isAutoConfiguring, setIsAutoConfiguring] = useState(false); const hasAttemptedAutoConfig = useRef(false); - // Check authentication useEffect(() => { - const token = getBearerToken(); - if (!token) { - redirectToLogin(); - } + if (!getBearerToken()) redirectToLogin(); }, []); - const isOnboardingComplete = isLlmOnboardingComplete( - preferences.agent_llm_id, - globalConfigs.length > 0 - ); - - useEffect(() => { - if (!preferencesLoading && globalConfigsLoaded && isOnboardingComplete) { - router.push(`/dashboard/${searchSpaceId}/new-chat`); + const firstGlobalChatModel = useMemo(() => { + for (const connection of globalConnections) { + const model = connection.models.find((item) => item.enabled && item.capabilities?.chat); + if (model) return model; } - }, [preferencesLoading, globalConfigsLoaded, isOnboardingComplete, router, searchSpaceId]); + return null; + }, [globalConnections]); + + const isComplete = (roles.chat_model_id ?? 0) !== 0 || Boolean(firstGlobalChatModel); useEffect(() => { - const autoConfigureWithGlobal = async () => { - if (hasAttemptedAutoConfig.current) return; - if (globalConfigsLoading || preferencesLoading) return; - if (!globalConfigsLoaded) return; - if (isOnboardingComplete) return; + if (globalLoading || rolesLoading || hasAttemptedAutoConfig.current) return; + if ((roles.chat_model_id ?? 0) !== 0) { + router.push(`/dashboard/${searchSpaceId}/new-chat`); + return; + } + if (!firstGlobalChatModel) return; - if (globalConfigs.length > 0) { - hasAttemptedAutoConfig.current = true; - setIsAutoConfiguring(true); - - try { - const firstGlobalConfig = globalConfigs[0]; - - await updatePreferences({ - search_space_id: searchSpaceId, - data: { - agent_llm_id: firstGlobalConfig.id, - }, - }); - - toast.success("AI configured automatically!", { - description: `Using ${firstGlobalConfig.name}. You can customize this later in Settings.`, - }); - - router.push(`/dashboard/${searchSpaceId}/new-chat`); - } catch (error) { - console.error("Auto-configuration failed:", error); - toast.error("Auto-configuration failed. Please add a configuration manually."); - setIsAutoConfiguring(false); - } - } - }; - - autoConfigureWithGlobal(); + hasAttemptedAutoConfig.current = true; + setIsAutoConfiguring(true); + updateRoles({ chat_model_id: firstGlobalChatModel.id }) + .then(() => { + toast.success("AI configured automatically", { + description: `Using ${firstGlobalChatModel.display_name || firstGlobalChatModel.model_id}.`, + }); + router.push(`/dashboard/${searchSpaceId}/new-chat`); + }) + .catch((error) => { + console.error("Auto-configuration failed:", error); + toast.error("Auto-configuration failed. Add a connection manually."); + setIsAutoConfiguring(false); + }); }, [ - globalConfigs, - globalConfigsLoading, - globalConfigsLoaded, - preferencesLoading, - isOnboardingComplete, - updatePreferences, - searchSpaceId, + firstGlobalChatModel, + globalLoading, + roles.chat_model_id, + rolesLoading, router, + searchSpaceId, + updateRoles, ]); - const handleSubmit = async (formData: LLMConfigFormData) => { - try { - const newConfig = await createConfig(formData); - - await updatePreferences({ - search_space_id: searchSpaceId, - data: { - agent_llm_id: newConfig.id, - }, - }); - - toast.success("Configuration created!", { - description: "Redirecting to chat...", - }); - - router.push(`/dashboard/${searchSpaceId}/new-chat`); - } catch (error) { - console.error("Failed to create config:", error); - if (error instanceof Error) { - toast.error(error.message || "Failed to create configuration"); - } - } - }; - - const isSubmitting = isCreating || isUpdatingPreferences; - - const isLoading = globalConfigsLoading || preferencesLoading || isAutoConfiguring; + const isLoading = globalLoading || rolesLoading || isAutoConfiguring || isPending; useGlobalLoadingEffect(isLoading); - if (isLoading) { - return null; - } - - if (globalConfigs.length > 0 && !isAutoConfiguring) { - return null; - } + if (isLoading || isComplete) return null; return ( -
-
- {/* Header */} -
- -
-

Configure Your AI

-

- Add your LLM provider to get started with SurfSense -

-
-
- - {/* Form card */} -
- -
- - {/* Footer */} -
- -

You can add more configurations later

+
+
+ +
+

Connect a Model

+

+ Add one connection, discover its models, then choose a chat model for this search space. +

+ + {isPending ? : null}
); diff --git a/surfsense_web/lib/posthog/events.ts b/surfsense_web/lib/posthog/events.ts index 4dc644d5e..5aac8943d 100644 --- a/surfsense_web/lib/posthog/events.ts +++ b/surfsense_web/lib/posthog/events.ts @@ -609,9 +609,9 @@ interface AutomationCreatedProps { task_count?: number; trigger_type?: string; has_schedule?: boolean; - agent_llm_id?: number; - image_generation_config_id?: number; - vision_llm_config_id?: number; + chat_model_id?: number; + image_gen_model_id?: number; + vision_model_id?: number; tags_count?: number; } @@ -705,9 +705,9 @@ interface AutomationChatDecisionProps { edited?: boolean; task_count?: number; trigger_type?: string; - agent_llm_id?: number; - image_generation_config_id?: number; - vision_llm_config_id?: number; + chat_model_id?: number; + image_gen_model_id?: number; + vision_model_id?: number; } export function trackAutomationChatApproved(props: AutomationChatDecisionProps) { From e46283992907102421ef6d95fcdb92f9f3ca989e Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:50:10 +0530 Subject: [PATCH 10/59] refactor(chat): simplify model selector connection flow --- .../components/new-chat/model-selector.tsx | 1542 ++--------------- surfsense_web/content/docs/how-to/ollama.mdx | 2 +- 2 files changed, 173 insertions(+), 1371 deletions(-) diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 0a096f5f8..7c912afbb 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -1,53 +1,27 @@ "use client"; -import { useAtomValue } from "jotai"; +import { useAtom, useAtomValue } from "jotai"; +import { Bot, Check, ChevronDown, ImageOff, Search, Settings2, Zap } from "lucide-react"; +import { useMemo, useState } from "react"; +import { updateModelRolesMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { - Bot, - Check, - ChevronDown, - ChevronLeft, - ChevronRight, - ChevronUp, - ImageIcon, - Layers, - Pencil, - Plus, - ScanEye, - Search, - Zap, -} from "lucide-react"; -import type React from "react"; -import { Fragment, useCallback, useEffect, useMemo, useRef, useState } from "react"; -import { toast } from "sonner"; -import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom"; -import { - globalImageGenConfigsAtom, - imageGenConfigsAtom, -} from "@/atoms/image-gen-config/image-gen-config-query.atoms"; -import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; -import { - globalNewLLMConfigsAtom, - llmPreferencesAtom, - newLLMConfigsAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; -import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; -import { - globalVisionLLMConfigsAtom, - visionLLMConfigsAtom, -} from "@/atoms/vision-llm-config/vision-llm-config-query.atoms"; + globalModelConnectionsAtom, + modelConnectionsAtom, + modelRolesAtom, +} from "@/atoms/model-connections/model-connections-query.atoms"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Drawer, DrawerContent, - DrawerHandle, DrawerHeader, DrawerTitle, DrawerTrigger, } from "@/components/ui/drawer"; +import { Input } from "@/components/ui/input"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Spinner } from "@/components/ui/spinner"; -import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; import type { GlobalImageGenConfig, GlobalNewLLMConfig, @@ -60,272 +34,6 @@ import { useIsMobile } from "@/hooks/use-mobile"; import { getProviderIcon } from "@/lib/provider-icons"; import { cn } from "@/lib/utils"; -// ─── Helpers ──────────────────────────────────────────────────────── - -const PROVIDER_NAMES: Record = { - OPENAI: "OpenAI", - ANTHROPIC: "Anthropic", - GOOGLE: "Google", - AZURE: "Azure", - AZURE_OPENAI: "Azure OpenAI", - AWS_BEDROCK: "AWS Bedrock", - BEDROCK: "Bedrock", - DEEPSEEK: "DeepSeek", - MISTRAL: "Mistral", - COHERE: "Cohere", - GITHUB_MODELS: "GitHub Models", - GROQ: "Groq", - OLLAMA: "Ollama", - TOGETHER_AI: "Together AI", - FIREWORKS_AI: "Fireworks AI", - REPLICATE: "Replicate", - HUGGINGFACE: "HuggingFace", - PERPLEXITY: "Perplexity", - XAI: "xAI", - OPENROUTER: "OpenRouter", - CEREBRAS: "Cerebras", - SAMBANOVA: "SambaNova", - VERTEX_AI: "Vertex AI", - MINIMAX: "MiniMax", - MOONSHOT: "Moonshot", - ZHIPU: "Zhipu", - DEEPINFRA: "DeepInfra", - CLOUDFLARE: "Cloudflare", - DATABRICKS: "Databricks", - NSCALE: "NScale", - RECRAFT: "Recraft", - XINFERENCE: "XInference", - CUSTOM: "Custom", - AI21: "AI21", - ALIBABA_QWEN: "Qwen", - ANYSCALE: "Anyscale", - COMETAPI: "CometAPI", -}; - -// Provider keys valid per model type, matching backend enums -// (LiteLLMProvider, ImageGenProvider, VisionProvider in db.py) -const LLM_PROVIDER_KEYS: string[] = [ - "OPENAI", - "ANTHROPIC", - "GOOGLE", - "AZURE_OPENAI", - "BEDROCK", - "VERTEX_AI", - "GROQ", - "DEEPSEEK", - "XAI", - "MISTRAL", - "COHERE", - "OPENROUTER", - "TOGETHER_AI", - "FIREWORKS_AI", - "REPLICATE", - "PERPLEXITY", - "OLLAMA", - "CEREBRAS", - "SAMBANOVA", - "DEEPINFRA", - "AI21", - "ALIBABA_QWEN", - "MOONSHOT", - "ZHIPU", - "MINIMAX", - "HUGGINGFACE", - "CLOUDFLARE", - "DATABRICKS", - "ANYSCALE", - "COMETAPI", - "GITHUB_MODELS", - "CUSTOM", -]; - -const IMAGE_PROVIDER_KEYS: string[] = [ - "OPENAI", - "AZURE_OPENAI", - "GOOGLE", - "VERTEX_AI", - "BEDROCK", - "RECRAFT", - "OPENROUTER", - "XINFERENCE", - "NSCALE", -]; - -const VISION_PROVIDER_KEYS: string[] = [ - "OPENAI", - "ANTHROPIC", - "GOOGLE", - "AZURE_OPENAI", - "VERTEX_AI", - "BEDROCK", - "XAI", - "OPENROUTER", - "OLLAMA", - "GROQ", - "TOGETHER_AI", - "FIREWORKS_AI", - "DEEPSEEK", - "MISTRAL", - "CUSTOM", -]; - -const PROVIDER_KEYS_BY_TAB: Record = { - llm: LLM_PROVIDER_KEYS, - image: IMAGE_PROVIDER_KEYS, - vision: VISION_PROVIDER_KEYS, -}; - -function formatProviderName(provider: string): string { - const key = provider.toUpperCase(); - return ( - PROVIDER_NAMES[key] ?? - provider.charAt(0).toUpperCase() + provider.slice(1).toLowerCase().replace(/_/g, " ") - ); -} - -function normalizeText(input: string): string { - return input - .normalize("NFD") - .replace(/\p{Diacritic}/gu, "") - .toLowerCase() - .replace(/[^a-z0-9]+/g, " ") - .trim(); -} - -interface ConfigBase { - id: number; - name: string; - model_name: string; - provider: string; -} - -function filterAndScore( - configs: T[], - selectedProvider: string, - searchQuery: string -): T[] { - let result = configs; - - if (selectedProvider !== "all") { - result = result.filter((c) => c.provider.toUpperCase() === selectedProvider); - } - - if (!searchQuery.trim()) return result; - - const normalized = normalizeText(searchQuery); - const tokens = normalized.split(/\s+/).filter(Boolean); - - const scored = result.map((c) => { - const aggregate = normalizeText([c.name, c.model_name, c.provider].join(" ")); - let score = 0; - if (aggregate.includes(normalized)) score += 5; - for (const token of tokens) { - if (aggregate.includes(token)) score += 1; - } - return { config: c, score }; - }); - - return scored - .filter((s) => s.score > 0) - .sort((a, b) => b.score - a.score) - .map((s) => s.config); -} - -interface DisplayItem { - config: ConfigBase & Record; - isGlobal: boolean; - isAutoMode: boolean; -} - -const TruncatedNameWithTooltip: React.FC<{ - text: string; - className?: string; - enableTooltip: boolean; -}> = ({ text, className, enableTooltip }) => { - const textRef = useRef(null); - const openTimerRef = useRef(undefined); - const [isTruncated, setIsTruncated] = useState(false); - const [open, setOpen] = useState(false); - - const recalcTruncation = useCallback(() => { - const el = textRef.current; - if (!el) return; - setIsTruncated(el.scrollWidth > el.clientWidth + 1); - }, []); - - useEffect(() => { - if (!enableTooltip) return; - const el = textRef.current; - if (!el) return; - - const raf = requestAnimationFrame(recalcTruncation); - recalcTruncation(); - - const observer = new ResizeObserver(recalcTruncation); - observer.observe(el); - if (el.parentElement) observer.observe(el.parentElement); - window.addEventListener("resize", recalcTruncation); - - return () => { - cancelAnimationFrame(raf); - observer.disconnect(); - window.removeEventListener("resize", recalcTruncation); - }; - }, [enableTooltip, recalcTruncation]); - - useEffect(() => { - // Recompute when row text changes. - void text; - requestAnimationFrame(recalcTruncation); - }, [text, recalcTruncation]); - - useEffect( - () => () => { - if (openTimerRef.current) window.clearTimeout(openTimerRef.current); - }, - [] - ); - - if (!enableTooltip) { - return ( - - {text} - - ); - } - - const handleOpenChange = (nextOpen: boolean) => { - if (openTimerRef.current) { - window.clearTimeout(openTimerRef.current); - openTimerRef.current = undefined; - } - if (!nextOpen) { - setOpen(false); - return; - } - if (!isTruncated) return; - openTimerRef.current = window.setTimeout(() => { - setOpen(true); - openTimerRef.current = undefined; - }, 220); - }; - - return ( - - - - {text} - - - - {text} - - - ); -}; - -// ─── Component ────────────────────────────────────────────────────── - interface ModelSelectorProps { onEditLLM: (config: NewLLMConfigPublic | GlobalNewLLMConfig, isGlobal: boolean) => void; onAddNewLLM: (provider?: string) => void; @@ -336,1113 +44,207 @@ interface ModelSelectorProps { className?: string; } +type ChatModel = ModelRead & { + connectionId: number; + connectionLabel: string; + provider: string; +}; + +function modelName(model: ModelRead) { + return model.display_name || model.model_id; +} + +function connectionLabel(connection: ConnectionRead) { + if (connection.scope === "GLOBAL") return "Hosted"; + return connection.native_provider || connection.protocol; +} + +function flattenChatModels(connections: ConnectionRead[]) { + return connections.flatMap((connection) => + connection.models + .filter((model) => model.enabled && Boolean(model.capabilities?.chat)) + .map((model) => ({ + ...model, + connectionId: connection.id, + connectionLabel: connectionLabel(connection), + provider: connection.native_provider || connection.protocol, + })) + ); +} + +function groupedModels(models: ChatModel[]) { + return models.reduce>((groups, model) => { + const key = model.connectionLabel; + if (!groups[key]) groups[key] = []; + groups[key].push(model); + return groups; + }, {}); +} + export function ModelSelector({ - onEditLLM, onAddNewLLM, + onEditLLM, onEditImage, onAddNewImage, onEditVision, onAddNewVision, className, }: ModelSelectorProps) { - const [open, setOpen] = useState(false); - const [activeTab, setActiveTab] = useState<"llm" | "image" | "vision">("llm"); - const [searchQuery, setSearchQuery] = useState(""); - const [selectedProvider, setSelectedProvider] = useState("all"); - const [focusedIndex, setFocusedIndex] = useState(-1); - const [modelScrollPos, setModelScrollPos] = useState<"top" | "middle" | "bottom">("top"); - const [sidebarScrollPos, setSidebarScrollPos] = useState<"top" | "middle" | "bottom">("top"); - const providerSidebarRef = useRef(null); - const modelListRef = useRef(null); - const searchInputRef = useRef(null); + void onEditLLM; + void onEditImage; + void onAddNewImage; + void onEditVision; + void onAddNewVision; + const isMobile = useIsMobile(); - - const handleOpenChange = useCallback( - (next: boolean) => { - if (next) { - setSearchQuery(""); - setSelectedProvider("all"); - if (!isMobile) { - requestAnimationFrame(() => searchInputRef.current?.focus()); - } - } - setOpen(next); - }, - [isMobile] + const [open, setOpen] = useState(false); + const [search, setSearch] = useState(""); + const [{ data: globalConnections = [], isLoading: globalLoading }] = useAtom( + globalModelConnectionsAtom ); + const [{ data: connections = [], isLoading: connectionsLoading }] = useAtom(modelConnectionsAtom); + const [{ data: roles }] = useAtom(modelRolesAtom); + const updateRoles = useAtomValue(updateModelRolesMutationAtom); - const handleTabChange = useCallback( - (next: "llm" | "image" | "vision") => { - setActiveTab(next); - setSelectedProvider("all"); - setSearchQuery(""); - setFocusedIndex(-1); - setModelScrollPos("top"); - if (open && !isMobile) { - requestAnimationFrame(() => searchInputRef.current?.focus()); - } - }, - [open, isMobile] - ); - - const handleModelListScroll = useCallback((e: React.UIEvent) => { - const el = e.currentTarget; - const atTop = el.scrollTop <= 2; - const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; - setModelScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); - }, []); - - const handleSidebarScroll = useCallback( - (e: React.UIEvent) => { - const el = e.currentTarget; - if (isMobile) { - const atStart = el.scrollLeft <= 2; - const atEnd = el.scrollWidth - el.scrollLeft - el.clientWidth <= 2; - setSidebarScrollPos(atStart ? "top" : atEnd ? "bottom" : "middle"); - } else { - const atTop = el.scrollTop <= 2; - const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; - setSidebarScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); - } - }, - [isMobile] - ); - - const scrollProviderSidebar = useCallback( - (direction: "backward" | "forward") => { - const el = providerSidebarRef.current; - if (!el) return; - const delta = isMobile - ? Math.max(56, Math.floor(el.clientWidth * 0.5)) - : Math.max(44, Math.floor(el.clientHeight * 0.4)); - - if (isMobile) { - el.scrollBy({ - left: direction === "backward" ? -delta : delta, - behavior: "smooth", - }); - return; - } - - el.scrollBy({ - top: direction === "backward" ? -delta : delta, - behavior: "smooth", - }); - }, - [isMobile] - ); - - // Cmd/Ctrl+M shortcut (desktop only) - useEffect(() => { - if (isMobile) return; - const handler = (e: KeyboardEvent) => { - if ((e.metaKey || e.ctrlKey) && e.key === "m") { - e.preventDefault(); - // setOpen((prev) => !prev); - handleOpenChange(!open); - } - }; - document.addEventListener("keydown", handler); - return () => document.removeEventListener("keydown", handler); - }, [isMobile, open, handleOpenChange]); - - // ─── Data ─── - const { data: llmUserConfigs, isLoading: llmUserLoading } = useAtomValue(newLLMConfigsAtom); - const { data: llmGlobalConfigs, isLoading: llmGlobalLoading } = - useAtomValue(globalNewLLMConfigsAtom); - const { data: preferences, isLoading: prefsLoading } = useAtomValue(llmPreferencesAtom); - const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom); - const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); - const { data: imageGlobalConfigs, isLoading: imageGlobalLoading } = - useAtomValue(globalImageGenConfigsAtom); - const { data: imageUserConfigs, isLoading: imageUserLoading } = useAtomValue(imageGenConfigsAtom); - const { data: visionGlobalConfigs, isLoading: visionGlobalLoading } = useAtomValue( - globalVisionLLMConfigsAtom - ); - const { data: visionUserConfigs, isLoading: visionUserLoading } = - useAtomValue(visionLLMConfigsAtom); - - // Pending image attachments on the composer. Used to surface an - // amber "No image" hint on chat models the catalog reports as - // non-vision (`supports_image_input=false`) when the next message - // will carry an image. The hint is purely advisory: selection, - // focus, and click handling are unaffected. The backend's safety - // net (`is_known_text_only_chat_model`) is the actual block, and - // it only fires when LiteLLM *explicitly* marks a model as - // text-only — so a model that's secretly capable but hasn't been - // annotated will still flow through to the provider. - const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom); - const hasPendingImages = pendingUserImageUrls.length > 0; - - const isLoading = - llmUserLoading || - llmGlobalLoading || - prefsLoading || - imageGlobalLoading || - imageUserLoading || - visionGlobalLoading || - visionUserLoading; - - // ─── Current selected configs ─── - const currentLLMConfig = useMemo(() => { - if (!preferences) return null; - const id = preferences.agent_llm_id; - if (id === null || id === undefined) return null; - if (id <= 0) return llmGlobalConfigs?.find((c) => c.id === id) ?? null; - return llmUserConfigs?.find((c) => c.id === id) ?? null; - }, [preferences, llmGlobalConfigs, llmUserConfigs]); - - const isLLMAutoMode = - currentLLMConfig && "is_auto_mode" in currentLLMConfig && currentLLMConfig.is_auto_mode; - - const currentImageConfig = useMemo(() => { - if (!preferences) return null; - const id = preferences.image_generation_config_id; - if (id === null || id === undefined) return null; - return ( - imageGlobalConfigs?.find((c) => c.id === id) ?? - imageUserConfigs?.find((c) => c.id === id) ?? - null + const chatModels = useMemo(() => { + const normalized = search.trim().toLowerCase(); + const models = flattenChatModels([...globalConnections, ...connections]); + if (!normalized) return models; + return models.filter((model) => + [modelName(model), model.model_id, model.connectionLabel] + .join(" ") + .toLowerCase() + .includes(normalized) ); - }, [preferences, imageGlobalConfigs, imageUserConfigs]); + }, [globalConnections, connections, search]); - const isImageAutoMode = - currentImageConfig && "is_auto_mode" in currentImageConfig && currentImageConfig.is_auto_mode; + const selected = chatModels.find((model) => model.id === roles?.chat_model_id); + const groups = groupedModels(chatModels); + const loading = globalLoading || connectionsLoading; - const currentVisionConfig = useMemo(() => { - if (!preferences) return null; - const id = preferences.vision_llm_config_id; - if (id === null || id === undefined) return null; - return ( - visionGlobalConfigs?.find((c) => c.id === id) ?? - visionUserConfigs?.find((c) => c.id === id) ?? - null - ); - }, [preferences, visionGlobalConfigs, visionUserConfigs]); + function selectModel(modelId: number) { + updateRoles.mutate({ chat_model_id: modelId }); + setOpen(false); + } - const isVisionAutoMode = - currentVisionConfig && - "is_auto_mode" in currentVisionConfig && - currentVisionConfig.is_auto_mode; - - // ─── Filtered configs (separate global / user for section headers) ─── - const filteredLLMGlobal = useMemo( - () => filterAndScore(llmGlobalConfigs ?? [], selectedProvider, searchQuery), - [llmGlobalConfigs, selectedProvider, searchQuery] - ); - const filteredLLMUser = useMemo( - () => filterAndScore(llmUserConfigs ?? [], selectedProvider, searchQuery), - [llmUserConfigs, selectedProvider, searchQuery] - ); - const filteredImageGlobal = useMemo( - () => filterAndScore(imageGlobalConfigs ?? [], selectedProvider, searchQuery), - [imageGlobalConfigs, selectedProvider, searchQuery] - ); - const filteredImageUser = useMemo( - () => filterAndScore(imageUserConfigs ?? [], selectedProvider, searchQuery), - [imageUserConfigs, selectedProvider, searchQuery] - ); - const filteredVisionGlobal = useMemo( - () => filterAndScore(visionGlobalConfigs ?? [], selectedProvider, searchQuery), - [visionGlobalConfigs, selectedProvider, searchQuery] - ); - const filteredVisionUser = useMemo( - () => filterAndScore(visionUserConfigs ?? [], selectedProvider, searchQuery), - [visionUserConfigs, selectedProvider, searchQuery] - ); - - // Combined display list for keyboard navigation - const currentDisplayItems: DisplayItem[] = useMemo(() => { - const toItems = (configs: ConfigBase[], isGlobal: boolean): DisplayItem[] => - configs.map((c) => ({ - config: c as ConfigBase & Record, - isGlobal, - isAutoMode: - isGlobal && "is_auto_mode" in c && !!(c as Record).is_auto_mode, - })); - - const sortGlobalItems = (items: DisplayItem[]): DisplayItem[] => - [...items].sort((a, b) => { - if (a.isAutoMode !== b.isAutoMode) return a.isAutoMode ? -1 : 1; - const aPremium = !!(a.config as Record).is_premium; - const bPremium = !!(b.config as Record).is_premium; - if (aPremium !== bPremium) return aPremium ? 1 : -1; - return 0; - }); - - switch (activeTab) { - case "llm": - return [ - ...sortGlobalItems(toItems(filteredLLMGlobal, true)), - ...toItems(filteredLLMUser, false), - ]; - case "image": - return [ - ...sortGlobalItems(toItems(filteredImageGlobal, true)), - ...toItems(filteredImageUser, false), - ]; - case "vision": - return [ - ...sortGlobalItems(toItems(filteredVisionGlobal, true)), - ...toItems(filteredVisionUser, false), - ]; - } - }, [ - activeTab, - filteredLLMGlobal, - filteredLLMUser, - filteredImageGlobal, - filteredImageUser, - filteredVisionGlobal, - filteredVisionUser, - ]); - - // ─── Provider sidebar data ─── - // Collect which providers actually have configured models for the active tab - const configuredProviderSet = useMemo(() => { - const configs = - activeTab === "llm" - ? [...(llmGlobalConfigs ?? []), ...(llmUserConfigs ?? [])] - : activeTab === "image" - ? [...(imageGlobalConfigs ?? []), ...(imageUserConfigs ?? [])] - : [...(visionGlobalConfigs ?? []), ...(visionUserConfigs ?? [])]; - const set = new Set(); - for (const c of configs) { - if (c.provider) set.add(c.provider.toUpperCase()); - } - return set; - }, [ - activeTab, - llmGlobalConfigs, - llmUserConfigs, - imageGlobalConfigs, - imageUserConfigs, - visionGlobalConfigs, - visionUserConfigs, - ]); - - // Show only providers valid for the active tab; configured ones first - const activeProviders = useMemo(() => { - const tabKeys = PROVIDER_KEYS_BY_TAB[activeTab] ?? LLM_PROVIDER_KEYS; - const configured = tabKeys.filter((p) => configuredProviderSet.has(p)); - const unconfigured = tabKeys.filter((p) => !configuredProviderSet.has(p)); - return ["all", ...configured, ...unconfigured]; - }, [activeTab, configuredProviderSet]); - - const providerModelCounts = useMemo(() => { - const allConfigs = - activeTab === "llm" - ? [...(llmGlobalConfigs ?? []), ...(llmUserConfigs ?? [])] - : activeTab === "image" - ? [...(imageGlobalConfigs ?? []), ...(imageUserConfigs ?? [])] - : [...(visionGlobalConfigs ?? []), ...(visionUserConfigs ?? [])]; - const counts: Record = { all: allConfigs.length }; - for (const c of allConfigs) { - const p = c.provider.toUpperCase(); - counts[p] = (counts[p] || 0) + 1; - } - return counts; - }, [ - activeTab, - llmGlobalConfigs, - llmUserConfigs, - imageGlobalConfigs, - imageUserConfigs, - visionGlobalConfigs, - visionUserConfigs, - ]); - - // ─── Selection handlers ─── - const handleSelectLLM = useCallback( - async (config: NewLLMConfigPublic | GlobalNewLLMConfig) => { - if (currentLLMConfig?.id === config.id) { - setOpen(false); - return; - } - if (!searchSpaceId) { - toast.error("No search space selected"); - return; - } - try { - await updatePreferences({ - search_space_id: Number(searchSpaceId), - data: { agent_llm_id: config.id }, - }); - toast.success(`Switched to ${config.name}`); - setOpen(false); - } catch { - toast.error("Failed to switch model"); - } - }, - [currentLLMConfig, searchSpaceId, updatePreferences] - ); - - const handleSelectImage = useCallback( - async (configId: number) => { - if (currentImageConfig?.id === configId) { - setOpen(false); - return; - } - if (!searchSpaceId) { - toast.error("No search space selected"); - return; - } - try { - await updatePreferences({ - search_space_id: Number(searchSpaceId), - data: { image_generation_config_id: configId }, - }); - toast.success("Image model updated"); - setOpen(false); - } catch { - toast.error("Failed to switch image model"); - } - }, - [currentImageConfig, searchSpaceId, updatePreferences] - ); - - const handleSelectVision = useCallback( - async (configId: number) => { - if (currentVisionConfig?.id === configId) { - setOpen(false); - return; - } - if (!searchSpaceId) { - toast.error("No search space selected"); - return; - } - try { - await updatePreferences({ - search_space_id: Number(searchSpaceId), - data: { vision_llm_config_id: configId }, - }); - toast.success("Vision model updated"); - setOpen(false); - } catch { - toast.error("Failed to switch vision model"); - } - }, - [currentVisionConfig, searchSpaceId, updatePreferences] - ); - - const handleSelectItem = useCallback( - (item: DisplayItem) => { - switch (activeTab) { - case "llm": - handleSelectLLM(item.config as NewLLMConfigPublic | GlobalNewLLMConfig); - break; - case "image": - handleSelectImage(item.config.id); - break; - case "vision": - handleSelectVision(item.config.id); - break; - } - }, - [activeTab, handleSelectLLM, handleSelectImage, handleSelectVision] - ); - - const handleEditItem = useCallback( - (e: React.MouseEvent, item: DisplayItem) => { - e.stopPropagation(); - setOpen(false); - switch (activeTab) { - case "llm": - onEditLLM(item.config as NewLLMConfigPublic | GlobalNewLLMConfig, item.isGlobal); - break; - case "image": - onEditImage?.(item.config as ImageGenerationConfig | GlobalImageGenConfig, item.isGlobal); - break; - case "vision": - onEditVision?.(item.config as VisionLLMConfig | GlobalVisionLLMConfig, item.isGlobal); - break; - } - }, - [activeTab, onEditLLM, onEditImage, onEditVision] - ); - - // ─── Keyboard navigation ─── - // biome-ignore lint/correctness/useExhaustiveDependencies: searchQuery and selectedProvider are intentional triggers to reset focus - useEffect(() => { - setFocusedIndex(-1); - }, [searchQuery, selectedProvider]); - - useEffect(() => { - if (focusedIndex < 0 || !modelListRef.current) return; - const items = modelListRef.current.querySelectorAll("[data-model-index]"); - items[focusedIndex]?.scrollIntoView({ - block: "nearest", - behavior: "smooth", - }); - }, [focusedIndex]); - - const handleKeyDown = useCallback( - (e: React.KeyboardEvent) => { - const count = currentDisplayItems.length; - - // Arrow Left/Right cycle provider filters - if (e.key === "ArrowLeft" || e.key === "ArrowRight") { - e.preventDefault(); - const providers = activeProviders; - const idx = providers.indexOf(selectedProvider); - let next: number; - if (e.key === "ArrowLeft") { - next = idx > 0 ? idx - 1 : providers.length - 1; - } else { - next = idx < providers.length - 1 ? idx + 1 : 0; - } - setSelectedProvider(providers[next]); - if (providerSidebarRef.current) { - const buttons = providerSidebarRef.current.querySelectorAll("button"); - buttons[next]?.scrollIntoView({ - block: "nearest", - inline: "nearest", - behavior: "smooth", - }); - } - return; - } - - if (count === 0) return; - - switch (e.key) { - case "ArrowDown": - e.preventDefault(); - setFocusedIndex((prev) => (prev < count - 1 ? prev + 1 : 0)); - break; - case "ArrowUp": - e.preventDefault(); - setFocusedIndex((prev) => (prev > 0 ? prev - 1 : count - 1)); - break; - case "Enter": - e.preventDefault(); - if (focusedIndex >= 0 && focusedIndex < count) { - handleSelectItem(currentDisplayItems[focusedIndex]); - } - break; - case "Home": - e.preventDefault(); - setFocusedIndex(0); - break; - case "End": - e.preventDefault(); - setFocusedIndex(count - 1); - break; - } - }, - [currentDisplayItems, focusedIndex, activeProviders, selectedProvider, handleSelectItem] - ); - - // ─── Render: Provider sidebar ─── - const renderProviderSidebar = () => { - const configuredCount = configuredProviderSet.size; - - return ( -
- {!isMobile && ( -
- -
- )} - {isMobile && ( -
- -
- )} -
+
+
+ + setSearch(event.target.value)} + placeholder="Search chat models..." + className="pl-9" + /> +
+
+
+ - - - {isAll ? "All Models" : formatProviderName(provider)} - {isConfigured ? ` (${count})` : " (not configured)"} - - - - ); - })} -
- {!isMobile && ( -
- -
- )} - {isMobile && ( -
- -
- )} -
- ); - }; - - // ─── Render: Model card ─── - const getSelectedId = () => { - switch (activeTab) { - case "llm": - return currentLLMConfig?.id; - case "image": - return currentImageConfig?.id; - case "vision": - return currentVisionConfig?.id; - } - }; - - const renderModelCard = (item: DisplayItem, index: number) => { - const { config, isAutoMode } = item; - const isSelected = getSelectedId() === config.id; - const isFocused = focusedIndex === index; - const hasCitations = "citations_enabled" in config && !!config.citations_enabled; - const hasPremiumStatus = "is_premium" in config && !isAutoMode; - const isPremium = hasPremiumStatus && !!(config as Record).is_premium; - // Chat-tab only: surface an amber "No image" hint when the - // composer carries images and the catalog reports the model as - // non-vision. This is purely advisory — selection is *not* - // blocked. The backend's narrow safety net - // (`is_known_text_only_chat_model`) is the source of truth for - // rejecting image turns, and it only fires when LiteLLM - // explicitly marks the model as text-only. A model surfaced as - // `supports_image_input=false` here may still be capable in - // practice (unknown / unmapped LiteLLM entry), so we let the - // user pick it and the provider response decide. - const isImageIncompatibleChatModel = - activeTab === "llm" && - hasPendingImages && - "supports_image_input" in config && - (config as Record).supports_image_input === false; - - return ( -
handleSelectItem(item)} - onKeyDown={ - isMobile - ? undefined - : (e) => { - if (e.key === "Enter" || e.key === " ") { - e.preventDefault(); - handleSelectItem(item); - } - } - } - onMouseEnter={() => setFocusedIndex(index)} - className={cn( - "group flex items-center gap-2.5 px-3 py-2 rounded-xl", - "transition-colors duration-150 mx-2 cursor-pointer", - "hover:bg-accent hover:text-accent-foreground", - isFocused && "bg-accent text-accent-foreground", - isSelected && "bg-accent text-accent-foreground" - )} - > - {/* Provider icon */} -
- {getProviderIcon(config.provider as string, { - isAutoMode, - className: "size-5", - })} -
- - {/* Model info */} -
-
- - {isAutoMode && ( - - Recommended - - )} - {isImageIncompatibleChatModel && ( - - No image - - )} -
- {isAutoMode ? ( -
- Auto Mode +
+
+
- ) : ( - (hasPremiumStatus || hasCitations) && ( -
- {hasPremiumStatus && ( - - {isPremium ? "Premium" : "Free"} - - )} - {hasCitations && ( - - Citations - - )} +
+
Auto
+
Use the hosted/global router
+
+
+ {(roles?.chat_model_id ?? 0) === 0 ? : null} + + {loading ? ( +
+ +
+ ) : Object.keys(groups).length === 0 ? ( +
+ No enabled chat models. Add or enable models in Settings. +
+ ) : ( + Object.entries(groups).map(([connection, models]) => ( +
+
+ {connection}
- ) - )} -
- - {/* Actions */} -
- {!isAutoMode && ( - - )} - {isSelected && ( -
- -
- )} -
-
- ); - }; - - // ─── Render: Full content ─── - const renderContent = () => { - const globalItems = currentDisplayItems.filter((i) => i.isGlobal); - const userItems = currentDisplayItems.filter((i) => !i.isGlobal); - const globalStartIdx = 0; - const userStartIdx = globalItems.length; - - const addHandler = - activeTab === "llm" ? onAddNewLLM : activeTab === "image" ? onAddNewImage : onAddNewVision; - const addLabel = - activeTab === "llm" - ? "Add Model" - : activeTab === "image" - ? "Add Image Model" - : "Add Vision Model"; - - return ( -
- {/* Tab header */} -
-
- {( - [ - { - value: "llm" as const, - icon: Zap, - label: "LLM", - }, - { - value: "image" as const, - icon: ImageIcon, - label: "Image", - }, - { - value: "vision" as const, - icon: ScanEye, - label: "Vision", - }, - ] as const - ).map(({ value, icon: Icon, label }) => ( - - ))} -
-
- - {/* Two-pane layout */} -
- {/* Provider sidebar */} - {renderProviderSidebar()} - - {/* Main content */} -
- {/* Search */} -
- - setSearchQuery(e.target.value)} - onKeyDown={isMobile ? undefined : handleKeyDown} - role="combobox" - aria-expanded={true} - aria-controls="model-selector-list" - className={cn( - "w-full pl-8 pr-3 py-2.5 text-sm bg-transparent", - "focus:outline-none", - "placeholder:text-muted-foreground" - )} - /> -
- - {/* Provider header when filtered */} - {selectedProvider !== "all" && ( -
- {getProviderIcon(selectedProvider, { - className: "size-4", - })} - {formatProviderName(selectedProvider)} - - {configuredProviderSet.has(selectedProvider) - ? `${providerModelCounts[selectedProvider] || 0} models` - : "Not configured"} - -
- )} - - {/* Model list */} -
- {currentDisplayItems.length === 0 ? ( -
- {selectedProvider !== "all" && !configuredProviderSet.has(selectedProvider) ? ( - <> -
- {getProviderIcon(selectedProvider, { - className: "size-10", - })} -
-

- No {formatProviderName(selectedProvider)} models configured -

-

- Add a model with this provider to get started -

- {addHandler && ( - - )} - - ) : searchQuery ? ( - <> - -

No models found

-

- Try a different search term -

- - ) : ( - <> -

- No models configured -

-

- Configure models in your search space settings -

- - )} -
- ) : ( - <> - {globalItems.length > 0 && ( - <> -
- Global Models -
- {globalItems.map((item, i) => renderModelCard(item, globalStartIdx + i))} - - )} - {globalItems.length > 0 && userItems.length > 0 && ( -
- )} - {userItems.length > 0 && ( - <> -
- Your Configurations -
- {userItems.map((item, i) => renderModelCard(item, userStartIdx + i))} - - )} - - )} -
- - {/* Add model button */} - {addHandler && ( -
- -
- )} -
-
+
+
+ {getProviderIcon(model.provider, { className: "size-4 shrink-0" })} + {modelName(model)} +
+
{model.model_id}
+
+
+ {!model.capabilities?.vision ? ( + + No image + + ) : null} + {roles?.chat_model_id === model.id ? : null} +
+ + ))} +
+ )) + )}
- ); - }; +
+ +
+
+ ); - // ─── Trigger button ─── - const triggerButton = ( + const trigger = ( ); - // ─── Shell: Drawer on mobile, Popover on desktop ─── if (isMobile) { return ( - - {triggerButton} - - - - Select Model + + {trigger} + + + Select Chat Model -
{renderContent()}
+ {content}
); } return ( - - {triggerButton} - e.preventDefault()} - > - {renderContent()} + + {trigger} + + {content} ); diff --git a/surfsense_web/content/docs/how-to/ollama.mdx b/surfsense_web/content/docs/how-to/ollama.mdx index 48b231705..f5ec09e1b 100644 --- a/surfsense_web/content/docs/how-to/ollama.mdx +++ b/surfsense_web/content/docs/how-to/ollama.mdx @@ -22,7 +22,7 @@ If SurfSense runs in Docker, do not use `localhost` unless Ollama is in the same ## 2) Add Ollama in SurfSense -Go to **Search Space Settings -> Agent Models -> Add Model** and set: +Go to **Search Space Settings -> Models -> Add Model** and set: - Provider: `OLLAMA` - Model name: your model tag, for example `llama3.2` or `qwen3:8b` From 6c352021a093bd6b3b69307b093a934c83d9255b Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:50:22 +0530 Subject: [PATCH 11/59] chore(i18n): add model connection copy --- surfsense_web/messages/en.json | 3 ++- surfsense_web/messages/es.json | 2 +- surfsense_web/messages/hi.json | 2 +- surfsense_web/messages/pt.json | 2 +- surfsense_web/messages/zh.json | 2 +- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/surfsense_web/messages/en.json b/surfsense_web/messages/en.json index a13942e64..80badd84f 100644 --- a/surfsense_web/messages/en.json +++ b/surfsense_web/messages/en.json @@ -743,7 +743,8 @@ "back_to_app": "Back to app", "nav_general": "General", "nav_general_desc": "Name, description & basic info", - "nav_agent_models": "Agent Models", + "nav_models": "Models", + "nav_agent_models": "Chat Models", "nav_agent_models_desc": "Models with prompts & citations", "nav_role_assignments": "Role Assignments", "nav_role_assignments_desc": "Assign configs to agent roles", diff --git a/surfsense_web/messages/es.json b/surfsense_web/messages/es.json index 33ae79c52..d6225f532 100644 --- a/surfsense_web/messages/es.json +++ b/surfsense_web/messages/es.json @@ -743,7 +743,7 @@ "back_to_app": "Volver a la app", "nav_general": "General", "nav_general_desc": "Nombre, descripción e información básica", - "nav_agent_models": "Modelos de agente", + "nav_agent_models": "Modelos de chat", "nav_agent_models_desc": "Modelos LLM con prompts y citas", "nav_role_assignments": "Asignaciones de roles", "nav_role_assignments_desc": "Asignar configuraciones a roles de agente", diff --git a/surfsense_web/messages/hi.json b/surfsense_web/messages/hi.json index 7a26d0c1d..3cb3ad41a 100644 --- a/surfsense_web/messages/hi.json +++ b/surfsense_web/messages/hi.json @@ -743,7 +743,7 @@ "back_to_app": "ऐप पर वापस जाएं", "nav_general": "सामान्य", "nav_general_desc": "नाम, विवरण और बुनियादी जानकारी", - "nav_agent_models": "एजेंट मॉडल", + "nav_agent_models": "चैट मॉडल", "nav_agent_models_desc": "प्रॉम्प्ट और उद्धरण के साथ LLM मॉडल", "nav_role_assignments": "भूमिका असाइनमेंट", "nav_role_assignments_desc": "एजेंट भूमिकाओं को कॉन्फ़िगरेशन असाइन करें", diff --git a/surfsense_web/messages/pt.json b/surfsense_web/messages/pt.json index 61c22e086..96bfc096d 100644 --- a/surfsense_web/messages/pt.json +++ b/surfsense_web/messages/pt.json @@ -743,7 +743,7 @@ "back_to_app": "Voltar ao app", "nav_general": "Geral", "nav_general_desc": "Nome, descrição e informações básicas", - "nav_agent_models": "Modelos do agente", + "nav_agent_models": "Modelos de chat", "nav_agent_models_desc": "Modelos LLM com prompts e citações", "nav_role_assignments": "Atribuições de funções", "nav_role_assignments_desc": "Atribuir configurações a funções do agente", diff --git a/surfsense_web/messages/zh.json b/surfsense_web/messages/zh.json index 7d0419cbd..03b217f12 100644 --- a/surfsense_web/messages/zh.json +++ b/surfsense_web/messages/zh.json @@ -727,7 +727,7 @@ "back_to_app": "返回应用", "nav_general": "常规", "nav_general_desc": "名称、描述和基本信息", - "nav_agent_models": "代理模型", + "nav_agent_models": "聊天模型", "nav_agent_models_desc": "LLM 模型配置提示词和引用", "nav_role_assignments": "角色分配", "nav_role_assignments_desc": "为代理角色分配配置", From 85114d2a0e0ed04c67aee49fba12d341cad02322 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:50:42 +0530 Subject: [PATCH 12/59] refactor(chat): rename image generation config parameters for clarity --- .../multi_agent_chat/main_agent/runtime/agent_cache.py | 4 ++-- .../chat/multi_agent_chat/main_agent/runtime/factory.py | 8 ++++---- surfsense_backend/app/agents/chat/runtime/llm_config.py | 4 ++-- surfsense_backend/app/agents/podcaster/nodes.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache.py index 6ac22e575..2d3599de0 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/agent_cache.py @@ -57,7 +57,7 @@ async def build_agent_with_cache( mcp_tools_by_agent: dict[str, list[BaseTool]], disabled_tools: list[str] | None, config_id: str | None, - image_generation_config_id_override: int | None = None, + image_gen_model_id_override: int | None = None, ) -> Any: """Compile the multi-agent graph, serving from cache when key components are stable.""" @@ -121,7 +121,7 @@ async def build_agent_with_cache( # Bound into the generate_image subagent tool at construction time, so it # must key the compiled-agent cache to avoid leaking one automation's # image model into another with the same config_id/search_space. - image_generation_config_id_override, + image_gen_model_id_override, ) return await get_cache().get_or_build(cache_key, builder=_build) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py index adb1bc1ed..10a734192 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/runtime/factory.py @@ -72,11 +72,11 @@ async def create_multi_agent_chat_deep_agent( mentioned_document_ids: list[int] | None = None, anon_session_id: str | None = None, filesystem_selection: FilesystemSelection | None = None, - image_generation_config_id: int | None = None, + image_gen_model_id: int | None = None, ): """Deep agent with SurfSense tools/middleware; registry route subagents behind ``task`` when enabled. - ``image_generation_config_id`` overrides the search space's image model for + ``image_gen_model_id`` overrides the search space's image model for this invocation (used by automations to run on their captured model). When ``None``, the ``generate_image`` tool resolves the live search-space pref. """ @@ -147,7 +147,7 @@ async def create_multi_agent_chat_deep_agent( "llm": llm, # Per-invocation image model override (automations run on their captured # model). Reaches the generate_image subagent tool via subagent_dependencies. - "image_generation_config_id_override": image_generation_config_id, + "image_gen_model_id_override": image_gen_model_id, } _t0 = time.perf_counter() @@ -303,7 +303,7 @@ async def create_multi_agent_chat_deep_agent( mcp_tools_by_agent=mcp_tools_by_agent, disabled_tools=disabled_tools, config_id=config_id, - image_generation_config_id_override=image_generation_config_id, + image_gen_model_id_override=image_gen_model_id, ) _perf_log.info( "[create_agent] Middleware stack + graph compiled in %.3fs", diff --git a/surfsense_backend/app/agents/chat/runtime/llm_config.py b/surfsense_backend/app/agents/chat/runtime/llm_config.py index aad432edb..03d7f548e 100644 --- a/surfsense_backend/app/agents/chat/runtime/llm_config.py +++ b/surfsense_backend/app/agents/chat/runtime/llm_config.py @@ -351,7 +351,7 @@ async def load_agent_llm_config_for_search_space( session: AsyncSession, search_space_id: int, ) -> "AgentConfig | None": - """Load the agent LLM config for a search space via its agent_llm_id. + """Load the chat model config for a search space via its agent_llm_id. Positive id -> DB; negative -> YAML; None -> first global config (-1). """ @@ -372,7 +372,7 @@ async def load_agent_llm_config_for_search_space( ) return await load_agent_config(session, config_id, search_space_id) except Exception as e: - print(f"Error loading agent LLM config for search space {search_space_id}: {e}") + print(f"Error loading chat model config for search space {search_space_id}: {e}") return None diff --git a/surfsense_backend/app/agents/podcaster/nodes.py b/surfsense_backend/app/agents/podcaster/nodes.py index d1f140a44..0d54cbe45 100644 --- a/surfsense_backend/app/agents/podcaster/nodes.py +++ b/surfsense_backend/app/agents/podcaster/nodes.py @@ -31,7 +31,7 @@ async def create_podcast_transcript( llm = await get_agent_llm(state.db_session, search_space_id) if not llm: - error_message = f"No agent LLM configured for search space {search_space_id}" + error_message = f"No chat model configured for search space {search_space_id}" print(error_message) raise RuntimeError(error_message) From 780e24213240310d52071c89c528ad1643d86071 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 00:11:53 +0530 Subject: [PATCH 13/59] feat(model-connections): implement manual model addition and enhance model discovery --- .../app/routes/model_connections_routes.py | 38 +++ surfsense_backend/app/schemas/__init__.py | 1 + .../app/schemas/model_connections.py | 11 + .../app/services/model_connection_service.py | 28 +- .../model-connections-mutation.atoms.ts | 28 +- .../settings/model-connections-settings.tsx | 312 ++++++++++++------ .../types/model-connections.types.ts | 6 + .../lib/apis/model-connections-api.service.ts | 12 + 8 files changed, 335 insertions(+), 101 deletions(-) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 69910183d..6d19a5ed1 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -10,6 +10,7 @@ from app.db import ( Connection, ConnectionScope, Model, + ModelSource, Permission, SearchSpace, User, @@ -19,6 +20,7 @@ from app.schemas import ( ConnectionCreate, ConnectionRead, ConnectionUpdate, + ModelCreate, ModelRead, ModelRolesRead, ModelRolesUpdate, @@ -26,6 +28,7 @@ from app.schemas import ( VerifyConnectionResponse, ) from app.services.model_connection_service import ( + derive_capabilities, discover_models, persist_verification, test_model, @@ -254,6 +257,41 @@ async def discover_connection_models( return [_model_read(model) for model in conn.models] +@router.post("/model-connections/{connection_id}/models", response_model=ModelRead) +async def add_manual_model( + connection_id: int, + data: ModelCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + conn = await _load_connection(session, connection_id) + await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value) + + model_id = data.model_id.strip() + if not model_id: + raise HTTPException(status_code=400, detail="model_id is required") + if any(existing.model_id == model_id for existing in conn.models): + raise HTTPException(status_code=400, detail="Model already exists on this connection") + + capabilities = derive_capabilities(conn, model_id) + model = Model( + connection_id=conn.id, + model_id=model_id, + display_name=data.display_name or None, + source=ModelSource.MANUAL, + capabilities=capabilities, + capabilities_declared=capabilities, + capabilities_verified={}, + capabilities_override={}, + enabled=True, + catalog={}, + ) + session.add(model) + await session.commit() + await session.refresh(model) + return _model_read(model) + + @router.put("/models/{model_id}", response_model=ModelRead) async def update_model( model_id: int, diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index c14671c99..2a06eca5c 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -48,6 +48,7 @@ from .model_connections import ( ConnectionCreate, ConnectionRead, ConnectionUpdate, + ModelCreate, ModelRead, ModelRolesRead, ModelRolesUpdate, diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py index 731064375..ea1ec4e88 100644 --- a/surfsense_backend/app/schemas/model_connections.py +++ b/surfsense_backend/app/schemas/model_connections.py @@ -65,6 +65,17 @@ class ConnectionUpdate(BaseModel): enabled: bool | None = None +class ModelCreate(BaseModel): + """Manually register a model id on a connection. + + For providers without a usable ``/models`` endpoint (Perplexity, MiniMax, + Azure deployments, etc.) or to pin a single model from a noisy provider. + """ + + model_id: str = Field(..., max_length=255) + display_name: str | None = Field(None, max_length=255) + + class ModelUpdate(BaseModel): display_name: str | None = Field(None, max_length=255) enabled: bool | None = None diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index 81090acaf..c8d2e8a5a 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -122,6 +122,17 @@ def _litellm_capabilities(model_string: str, model_id: str) -> dict[str, bool]: return capabilities +def _allowlist(conn: Connection) -> set[str]: + """Per-connection model-id allowlist stored in ``extra.model_ids``. + + Empty/absent means "no restriction" (discover everything), mirroring + OpenWebUI's behaviour. A non-empty list restricts discovery to those ids — + essential for providers like OpenRouter that expose hundreds of models. + """ + raw = (conn.extra or {}).get("model_ids") or [] + return {str(item).strip() for item in raw if str(item).strip()} + + def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, bool]: metadata = metadata or {} if conn.protocol == ConnectionProtocol.OLLAMA: @@ -140,13 +151,15 @@ def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = async def discover_models(conn: Connection) -> list[dict[str, Any]]: + allowlist = _allowlist(conn) + if conn.protocol == ConnectionProtocol.OLLAMA: url = f"{conn.base_url.rstrip('/')}/api/tags" async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: response = await client.get(url, headers=_auth_headers(conn)) response.raise_for_status() models = response.json().get("models", []) - return [ + results = [ { "model_id": item.get("model") or item.get("name"), "display_name": item.get("name") or item.get("model"), @@ -157,14 +170,13 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]: for item in models if item.get("model") or item.get("name") ] - - if conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: + elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: url = f"{ensure_v1(conn.base_url)}/models" async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: response = await client.get(url, headers=_auth_headers(conn)) response.raise_for_status() models = response.json().get("data", []) - return [ + results = [ { "model_id": item.get("id"), "display_name": item.get("name") or item.get("id"), @@ -175,9 +187,13 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]: for item in models if item.get("id") ] + else: + # Native providers rely on curated/global catalog entries or manual rows. + return [] - # Native providers rely on curated/global catalog entries or manual rows. - return [] + if allowlist: + results = [item for item in results if item["model_id"] in allowlist] + return results async def test_model(conn: Connection, model: Model) -> VerifyResult: diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts index 7d58a402c..612216bf2 100644 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -3,6 +3,7 @@ import { toast } from "sonner"; import type { ConnectionCreateRequest, ConnectionUpdateRequest, + ModelCreateRequest, ModelRoles, ModelUpdateRequest, } from "@/contracts/types/model-connections.types"; @@ -67,8 +68,17 @@ export const verifyModelConnectionMutationAtom = atomWithMutation((get) => { mutationKey: ["model-connections", "verify"], mutationFn: (id: number) => modelConnectionsApiService.verifyConnection(id), onSuccess: (result) => { - if (result.ok) toast.success("Connection verified"); - else toast.error(result.message || "Connection failed"); + if (result.ok) { + toast.success("Connection verified"); + } else { + // Non-fatal: many providers lack a /models endpoint yet still serve + // chat. Guide the user to add model IDs manually instead of alarming. + toast.warning( + result.message + ? `${result.message} Chat may still work — add model IDs manually.` + : "Couldn't list models. Chat may still work — add model IDs manually." + ); + } invalidateModelConnections(searchSpaceId); }, onError: (error: Error) => toast.error(error.message || "Failed to verify connection"), @@ -88,6 +98,20 @@ export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => { }; }); +export const addManualModelMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["models", "add-manual"], + mutationFn: ({ connectionId, data }: { connectionId: number; data: ModelCreateRequest }) => + modelConnectionsApiService.addManualModel(connectionId, data), + onSuccess: () => { + toast.success("Model added"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to add model"), + }; +}); + export const updateModelMutationAtom = atomWithMutation((get) => { const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); return { diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 5fa4cccf7..e89fc3278 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -1,12 +1,14 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { CheckCircle2, PlugZap, RefreshCcw, XCircle } from "lucide-react"; +import { CheckCircle2, PlugZap, Plus, RefreshCcw, XCircle } from "lucide-react"; import { useMemo, useState } from "react"; import { + addManualModelMutationAtom, createModelConnectionMutationAtom, discoverConnectionModelsMutationAtom, testModelMutationAtom, + updateModelConnectionMutationAtom, updateModelMutationAtom, updateModelRolesMutationAtom, verifyModelConnectionMutationAtom, @@ -46,9 +48,16 @@ type Preset = { }; const PRESETS: Preset[] = [ + { id: "custom", label: "OpenAI-compatible (any URL)", protocol: "OPENAI_COMPATIBLE" }, { id: "openai", label: "OpenAI", protocol: "NATIVE", nativeProvider: "OPENAI" }, { id: "anthropic", label: "Anthropic", protocol: "NATIVE", nativeProvider: "ANTHROPIC" }, - { id: "openrouter", label: "OpenRouter", protocol: "NATIVE", nativeProvider: "OPENROUTER" }, + { + id: "openrouter", + label: "OpenRouter", + protocol: "NATIVE", + nativeProvider: "OPENROUTER", + baseUrl: "https://openrouter.ai/api/v1", + }, { id: "ollama", label: "Ollama", @@ -86,6 +95,22 @@ const PRESETS: Preset[] = [ }, ]; +// Free-text URL hints (datalist), mirroring OpenWebUI. These never restrict +// what the user can type — any OpenAI-compatible endpoint works. +const URL_SUGGESTIONS = [ + "https://api.openai.com/v1", + "https://api.anthropic.com/v1", + "https://openrouter.ai/api/v1", + "https://generativelanguage.googleapis.com/v1beta/openai", + "https://api.groq.com/openai/v1", + "https://api.mistral.ai/v1", + "https://api.deepseek.com/v1", + "https://api.x.ai/v1", + "http://host.docker.internal:11434", + "http://host.docker.internal:1234/v1", + "http://host.docker.internal:8000/v1", +]; + function modelLabel(model: ModelRead) { return model.display_name || model.model_id; } @@ -123,22 +148,183 @@ function flattenModels(connections: ConnectionRead[]) { ); } +function ConnectionCard({ connection }: { connection: ConnectionRead }) { + const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom); + const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); + const updateConnection = useAtomValue(updateModelConnectionMutationAtom); + const addManualModel = useAtomValue(addManualModelMutationAtom); + const updateModel = useAtomValue(updateModelMutationAtom); + const testModel = useAtomValue(testModelMutationAtom); + + const allowlist = Array.isArray(connection.extra?.model_ids) + ? (connection.extra.model_ids as string[]) + : []; + const [allowlistText, setAllowlistText] = useState(allowlist.join(", ")); + const [manualModelId, setManualModelId] = useState(""); + + const providerLabel = connection.native_provider || connection.protocol; + const isLocal = connection.protocol === "OLLAMA" || !connection.base_url?.startsWith("https"); + + function saveAllowlist() { + const ids = allowlistText + .split(",") + .map((value) => value.trim()) + .filter(Boolean); + updateConnection.mutate({ + id: connection.id, + data: { extra: { ...(connection.extra ?? {}), model_ids: ids } }, + }); + } + + function addModel() { + const modelId = manualModelId.trim(); + if (!modelId) return; + addManualModel.mutate( + { connectionId: connection.id, data: { model_id: modelId } }, + { onSuccess: () => setManualModelId("") } + ); + } + + return ( +
+
+
+
+ {getProviderIcon(providerLabel, { className: "size-4" })} + {providerLabel} +
+
+ {connection.base_url || "Provider default endpoint"} +
+
+
+ + + +
+
+ + {connection.last_status && connection.last_status !== "OK" ? ( +

+ {connection.last_error || "Could not list models."} Chat may still work — add model + IDs manually below. +

+ ) : null} + + {!isLocal ? ( +
+ +
+ setAllowlistText(event.target.value)} + placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro" + /> + +
+

+ Leave empty to discover all models. Recommended for providers with large catalogs + (e.g. OpenRouter). +

+
+ ) : null} + +
+ setManualModelId(event.target.value)} + onKeyDown={(event) => { + if (event.key === "Enter") { + event.preventDefault(); + addModel(); + } + }} + placeholder="Add a model ID manually (for providers without /models)" + /> + +
+ +
+ {connection.models.map((model) => ( +
+
+
+ {getProviderIcon(providerLabel, { className: "size-4" })} + {modelLabel(model)} + {model.source === "MANUAL" ? ( + + manual + + ) : null} +
+
+ {["chat", "vision", "image_gen"] + .filter((key) => Boolean(model.capabilities?.[key])) + .join(", ") || "No verified capabilities"} +
+
+
+ + +
+
+ ))} +
+
+ ); +} + export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: number }) { const [{ data: globalConnections = [] }] = useAtom(globalModelConnectionsAtom); const [{ data: connections = [] }] = useAtom(modelConnectionsAtom); const [{ data: roles }] = useAtom(modelRolesAtom); const createConnection = useAtomValue(createModelConnectionMutationAtom); - const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom); - const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); - const updateModel = useAtomValue(updateModelMutationAtom); - const testModel = useAtomValue(testModelMutationAtom); const updateRoles = useAtomValue(updateModelRolesMutationAtom); const visiblePresets = useMemo( () => PRESETS.filter((preset) => !(isCloud() && preset.local)), [] ); - const [presetId, setPresetId] = useState(visiblePresets[0]?.id ?? "openai"); + const [presetId, setPresetId] = useState(visiblePresets[0]?.id ?? "custom"); const preset = visiblePresets.find((item) => item.id === presetId) ?? visiblePresets[0]; const [baseUrl, setBaseUrl] = useState(preset?.baseUrl ?? ""); const [apiKey, setApiKey] = useState(""); @@ -157,16 +343,19 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num function handleCreate() { if (!preset) return; - createConnection.mutate({ - protocol: preset.protocol, - native_provider: preset.nativeProvider, - base_url: baseUrl || null, - api_key: apiKey || null, - scope: "SEARCH_SPACE", - search_space_id: searchSpaceId, - extra: {}, - enabled: true, - }); + createConnection.mutate( + { + protocol: preset.protocol, + native_provider: preset.nativeProvider, + base_url: baseUrl || null, + api_key: apiKey || null, + scope: "SEARCH_SPACE", + search_space_id: searchSpaceId, + extra: {}, + enabled: true, + }, + { onSuccess: () => setApiKey("") } + ); } function renderModelOption(model: ModelRead & { connectionName: string; provider: string }) { @@ -192,7 +381,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
- +
- - setBaseUrl(event.target.value)} - placeholder="https://api.example.com/v1" - list="model-conn-url-suggestions" - /> - - {URL_SUGGESTIONS.map((url) => ( - + + {isNative && !showCustomEndpoint ? ( +
+
+ Uses provider default +
+ +
+ ) : ( + <> + setBaseUrl(event.target.value)} + placeholder="https://api.example.com/v1" + list="model-conn-url-suggestions" + /> + + {URL_SUGGESTIONS.map((url) => ( + + + )}
- + setApiKey(event.target.value)} - placeholder="Optional for local models" + placeholder={preset?.local ? "Optional for local models" : "API key"} type="password" />
-
@@ -434,10 +457,15 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num Local URLs are tested from the backend container. Use host.docker.internal instead of localhost.

+ ) : isNative ? ( +

+ Just paste an API key — {preset?.label} routes through its native endpoint + automatically. After adding, hit Discover (or add model IDs manually). +

) : preset?.protocol === "OPENAI_COMPATIBLE" ? (

- Works with any OpenAI-compatible endpoint (OpenRouter, Together, Groq, vLLM, LM - Studio…). After adding, hit Discover to list models. + Enter any OpenAI-compatible endpoint (OpenRouter, Together, Groq, vLLM, LM Studio…). + After adding, hit Discover to list models.

) : null} From 50c816c81c0b85e3c6f28fe4719e0be85d183cac Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 10:22:39 +0530 Subject: [PATCH 15/59] refactor(model-connections): streamline connection reading and model handling in routes --- .../app/routes/model_connections_routes.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 6d19a5ed1..5872671b1 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -69,7 +69,7 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None last_verified_at=conn.last_verified_at, last_status=conn.last_status, last_error=conn.last_error, - models=[_model_read(model) for model in (models or conn.models or [])], + models=[_model_read(model) for model in (models or [])], created_at=conn.created_at, ) @@ -144,7 +144,10 @@ async def list_connections( else: stmt = stmt.where(Connection.user_id == user.id) result = await session.execute(stmt.order_by(Connection.id)) - return [_connection_read(conn) for conn in result.scalars().all()] + return [ + _connection_read(conn, list(conn.models)) + for conn in result.scalars().all() + ] @router.post("/model-connections", response_model=ConnectionRead) @@ -173,7 +176,7 @@ async def create_connection( session.add(conn) await session.commit() await session.refresh(conn) - return _connection_read(conn) + return _connection_read(conn, []) @router.put("/model-connections/{connection_id}", response_model=ConnectionRead) @@ -188,8 +191,8 @@ async def update_connection( for key, value in data.model_dump(exclude_unset=True).items(): setattr(conn, key, value) await session.commit() - await session.refresh(conn) - return _connection_read(conn) + conn = await _load_connection(session, connection_id) + return _connection_read(conn, list(conn.models)) @router.delete("/model-connections/{connection_id}") @@ -253,7 +256,7 @@ async def discover_connection_models( } db_model.catalog = item.get("metadata") or db_model.catalog await session.commit() - await session.refresh(conn) + conn = await _load_connection(session, connection_id) return [_model_read(model) for model in conn.models] From 3f016421992919a63211df2adbdc13df39278fbc Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 17:29:55 +0530 Subject: [PATCH 16/59] feat(model-connections): enhance model discovery with OpenAI and LiteLLM support --- .../app/services/model_connection_service.py | 93 +++++++++++++++---- .../model-connections-mutation.atoms.ts | 8 +- 2 files changed, 80 insertions(+), 21 deletions(-) diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index c8d2e8a5a..42a4792a4 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import contextlib import logging from dataclasses import dataclass @@ -12,7 +13,8 @@ import httpx import litellm from app.db import Connection, ConnectionProtocol, Model, ModelSource -from app.services.model_resolver import ensure_v1, to_litellm +from app.services.model_resolver import NATIVE_PROVIDER_PREFIX, ensure_v1, to_litellm +from app.services.provider_api_base import resolve_api_base logger = logging.getLogger(__name__) @@ -133,6 +135,63 @@ def _allowlist(conn: Connection) -> set[str]: return {str(item).strip() for item in raw if str(item).strip()} +async def _discover_openai_shaped_models(conn: Connection, base_url: str | None) -> list[dict[str, Any]]: + if not base_url: + return [] + + url = f"{ensure_v1(base_url)}/models" + async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: + response = await client.get(url, headers=_auth_headers(conn)) + response.raise_for_status() + return [ + { + "model_id": item.get("id"), + "display_name": item.get("name") or item.get("id"), + "source": ModelSource.DISCOVERED, + "capabilities": derive_capabilities(conn, item.get("id"), item), + "metadata": item, + } + for item in response.json().get("data", []) + if item.get("id") + ] + + +def _litellm_valid_model_ids(provider: str, api_key: str | None) -> list[str]: + if not api_key: + return [] + + try: + models = litellm.get_valid_models( + check_provider_endpoint=True, + custom_llm_provider=provider, + api_key=api_key, + ) + except Exception as exc: + logger.warning("LiteLLM model discovery failed for provider %s: %s", provider, exc) + return [] + + provider_prefix = f"{provider}/" + return [ + model.removeprefix(provider_prefix) + for model in models + if isinstance(model, str) and model.strip() + ] + + +async def _discover_litellm_native_models(conn: Connection, provider: str) -> list[dict[str, Any]]: + model_ids = await asyncio.to_thread(_litellm_valid_model_ids, provider, conn.api_key) + return [ + { + "model_id": model_id, + "display_name": model_id, + "source": ModelSource.DISCOVERED, + "capabilities": derive_capabilities(conn, model_id), + "metadata": {}, + } + for model_id in model_ids + ] + + def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, bool]: metadata = metadata or {} if conn.protocol == ConnectionProtocol.OLLAMA: @@ -171,25 +230,21 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]: if item.get("model") or item.get("name") ] elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: - url = f"{ensure_v1(conn.base_url)}/models" - async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: - response = await client.get(url, headers=_auth_headers(conn)) - response.raise_for_status() - models = response.json().get("data", []) - results = [ - { - "model_id": item.get("id"), - "display_name": item.get("name") or item.get("id"), - "source": ModelSource.DISCOVERED, - "capabilities": derive_capabilities(conn, item.get("id"), item), - "metadata": item, - } - for item in models - if item.get("id") - ] + results = await _discover_openai_shaped_models(conn, conn.base_url) else: - # Native providers rely on curated/global catalog entries or manual rows. - return [] + provider_key = (conn.native_provider or "").upper() + provider = NATIVE_PROVIDER_PREFIX.get(provider_key, provider_key.lower()) + api_base = resolve_api_base( + provider=provider_key, + provider_prefix=provider, + config_api_base=conn.base_url, + ) + if api_base: + results = await _discover_openai_shaped_models(conn, api_base) + elif provider: + results = await _discover_litellm_native_models(conn, provider) + else: + results = [] if allowlist: results = [item for item in results if item["model_id"] in allowlist] diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts index 612216bf2..76289e60d 100644 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -90,8 +90,12 @@ export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => { return { mutationKey: ["model-connections", "discover"], mutationFn: (id: number) => modelConnectionsApiService.discoverModels(id), - onSuccess: () => { - toast.success("Models discovered"); + onSuccess: (models) => { + toast.success( + models.length + ? `${models.length} models discovered` + : "No models found for this connection" + ); invalidateModelConnections(searchSpaceId); }, onError: (error: Error) => toast.error(error.message || "Failed to discover models"), From c6a25cc1fe1e128852146a09d2877d3496a2aaee Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 18:20:53 +0530 Subject: [PATCH 17/59] refactor(model-connections): streamline global model config persistence --- .../versions/156_add_model_connections.py | 266 +++++++++++------- surfsense_backend/app/config/__init__.py | 9 +- .../app/config/global_llm_config.example.yaml | 60 ++-- surfsense_backend/app/db.py | 4 +- .../app/routes/model_connections_routes.py | 19 +- .../app/routes/new_llm_config_routes.py | 6 +- .../app/routes/search_spaces_routes.py | 6 +- .../app/schemas/model_connections.py | 6 +- .../app/services/model_connection_service.py | 85 +++--- .../tests/e2e/fixtures/global_llm_config.yaml | 8 +- .../routes/test_global_configs_is_premium.py | 12 +- ...t_global_new_llm_configs_supports_image.py | 8 +- .../unit/services/test_model_connections.py | 12 +- 13 files changed, 277 insertions(+), 224 deletions(-) diff --git a/surfsense_backend/alembic/versions/156_add_model_connections.py b/surfsense_backend/alembic/versions/156_add_model_connections.py index 0a11d7f9d..185debca4 100644 --- a/surfsense_backend/alembic/versions/156_add_model_connections.py +++ b/surfsense_backend/alembic/versions/156_add_model_connections.py @@ -20,7 +20,7 @@ depends_on: str | Sequence[str] | None = None connection_protocol = postgresql.ENUM( "OLLAMA", "OPENAI_COMPATIBLE", - "NATIVE", + "ANTHROPIC", name="connectionprotocol", create_type=False, ) @@ -39,122 +39,172 @@ model_source = postgresql.ENUM( ) +def _table_exists(table_name: str) -> bool: + return table_name in sa.inspect(op.get_bind()).get_table_names() + + +def _column_exists(table_name: str, column_name: str) -> bool: + if not _table_exists(table_name): + return False + return column_name in { + column["name"] for column in sa.inspect(op.get_bind()).get_columns(table_name) + } + + +def _index_exists(table_name: str, index_name: str) -> bool: + if not _table_exists(table_name): + return False + return index_name in { + index["name"] for index in sa.inspect(op.get_bind()).get_indexes(table_name) + } + + +def _create_index_if_missing( + index_name: str, + table_name: str, + columns: list[str], +) -> None: + if not _index_exists(table_name, index_name): + op.create_index(index_name, table_name, columns, unique=False) + + +def _add_searchspace_column_if_missing(column_name: str) -> None: + if not _column_exists("searchspaces", column_name): + op.add_column("searchspaces", sa.Column(column_name, sa.Integer(), nullable=True)) + + def upgrade() -> None: bind = op.get_bind() connection_protocol.create(bind, checkfirst=True) + op.execute("ALTER TYPE connectionprotocol ADD VALUE IF NOT EXISTS 'ANTHROPIC'") connection_scope.create(bind, checkfirst=True) model_source.create(bind, checkfirst=True) - op.create_table( + if _table_exists("connections"): + if _column_exists("connections", "native_provider") and not _column_exists( + "connections", "litellm_provider" + ): + op.alter_column( + "connections", + "native_provider", + new_column_name="litellm_provider", + existing_type=sa.String(length=100), + existing_nullable=True, + ) + elif not _column_exists("connections", "litellm_provider"): + op.add_column( + "connections", + sa.Column("litellm_provider", sa.String(length=100), nullable=True), + ) + else: + op.create_table( + "connections", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("protocol", connection_protocol, nullable=False), + sa.Column("litellm_provider", sa.String(length=100), nullable=True), + sa.Column("base_url", sa.String(length=500), nullable=True), + sa.Column("api_key", sa.String(), nullable=True), + sa.Column( + "extra", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column("scope", connection_scope, nullable=False), + sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False), + sa.Column("search_space_id", sa.Integer(), nullable=True), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("last_verified_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column("last_status", sa.String(length=50), nullable=True), + sa.Column("last_error", sa.Text(), nullable=True), + sa.CheckConstraint( + "(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR " + "(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR " + "(scope = 'USER' AND user_id IS NOT NULL)", + name="ck_connections_scope_owner", + ), + sa.ForeignKeyConstraint( + ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + if _index_exists("connections", "ix_connections_native_provider") and not _index_exists( + "connections", "ix_connections_litellm_provider" + ): + op.execute( + "ALTER INDEX ix_connections_native_provider " + "RENAME TO ix_connections_litellm_provider" + ) + _create_index_if_missing("ix_connections_protocol", "connections", ["protocol"]) + _create_index_if_missing( + "ix_connections_litellm_provider", "connections", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("protocol", connection_protocol, nullable=False), - sa.Column("native_provider", sa.String(length=100), nullable=True), - sa.Column("base_url", sa.String(length=500), nullable=True), - sa.Column("api_key", sa.String(), nullable=True), - sa.Column( - "extra", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column("scope", connection_scope, nullable=False), - sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False), - sa.Column("search_space_id", sa.Integer(), nullable=True), - sa.Column("user_id", sa.UUID(), nullable=True), - sa.Column("last_verified_at", sa.TIMESTAMP(timezone=True), nullable=True), - sa.Column("last_status", sa.String(length=50), nullable=True), - sa.Column("last_error", sa.Text(), nullable=True), - sa.CheckConstraint( - "(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR " - "(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR " - "(scope = 'USER' AND user_id IS NOT NULL)", - name="ck_connections_scope_owner", - ), - sa.ForeignKeyConstraint( - ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" - ), - sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("id"), + ["litellm_provider"], ) - op.create_index(op.f("ix_connections_protocol"), "connections", ["protocol"], unique=False) - op.create_index( - op.f("ix_connections_native_provider"), - "connections", - ["native_provider"], - unique=False, - ) - op.create_index(op.f("ix_connections_scope"), "connections", ["scope"], unique=False) + _create_index_if_missing("ix_connections_scope", "connections", ["scope"]) - op.create_table( - "models", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("connection_id", sa.Integer(), nullable=False), - sa.Column("model_id", sa.String(length=255), nullable=False), - sa.Column("display_name", sa.String(length=255), nullable=True), - sa.Column( - "source", - model_source, - server_default="DISCOVERED", - nullable=False, - ), - sa.Column( - "capabilities", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column( - "capabilities_declared", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column( - "capabilities_verified", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column( - "capabilities_override", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column("embedding_dimension", sa.Integer(), nullable=True), - sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False), - sa.Column("billing_tier", sa.String(length=50), nullable=True), - sa.Column( - "catalog", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.ForeignKeyConstraint(["connection_id"], ["connections.id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint( - "connection_id", "model_id", name="uq_models_connection_model_id" - ), - ) - op.create_index(op.f("ix_models_connection_id"), "models", ["connection_id"], unique=False) - op.create_index("ix_models_model_id", "models", ["model_id"], unique=False) - op.create_index(op.f("ix_models_billing_tier"), "models", ["billing_tier"], unique=False) + if not _table_exists("models"): + op.create_table( + "models", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("connection_id", sa.Integer(), nullable=False), + sa.Column("model_id", sa.String(length=255), nullable=False), + sa.Column("display_name", sa.String(length=255), nullable=True), + sa.Column( + "source", + model_source, + server_default="DISCOVERED", + nullable=False, + ), + sa.Column( + "capabilities", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column( + "capabilities_declared", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column( + "capabilities_verified", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column( + "capabilities_override", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column("embedding_dimension", sa.Integer(), nullable=True), + sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False), + sa.Column("billing_tier", sa.String(length=50), nullable=True), + sa.Column( + "catalog", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.ForeignKeyConstraint(["connection_id"], ["connections.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint( + "connection_id", "model_id", name="uq_models_connection_model_id" + ), + ) + _create_index_if_missing("ix_models_connection_id", "models", ["connection_id"]) + _create_index_if_missing("ix_models_model_id", "models", ["model_id"]) + _create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"]) - op.add_column( - "searchspaces", - sa.Column("chat_model_id", sa.Integer(), nullable=True), - ) - op.add_column( - "searchspaces", - sa.Column("image_gen_model_id", sa.Integer(), nullable=True), - ) - op.add_column( - "searchspaces", - sa.Column("vision_model_id", sa.Integer(), nullable=True), - ) + _add_searchspace_column_if_missing("chat_model_id") + _add_searchspace_column_if_missing("image_gen_model_id") + _add_searchspace_column_if_missing("vision_model_id") def downgrade() -> None: @@ -168,7 +218,7 @@ def downgrade() -> None: op.drop_table("models") op.drop_index(op.f("ix_connections_scope"), table_name="connections") - op.drop_index(op.f("ix_connections_native_provider"), table_name="connections") + op.drop_index(op.f("ix_connections_litellm_provider"), table_name="connections") op.drop_index(op.f("ix_connections_protocol"), table_name="connections") op.drop_table("connections") diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index b8addb45d..fd8b29116 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -78,8 +78,7 @@ def load_global_llm_configs(): # stamps) never leak into the cached YAML structure. configs = copy.deepcopy(data.get("global_llm_configs", [])) - # Lazy import keeps the `app.config` -> `app.services` edge one-way - # and matches the `provider_api_base` pattern used elsewhere. + # Lazy import keeps the `app.config` -> `app.services` edge one-way. from app.services.provider_capabilities import derive_supports_image_input seen_slugs: dict[str, int] = {} @@ -104,7 +103,7 @@ def load_global_llm_configs(): else None ) cfg["supports_image_input"] = derive_supports_image_input( - provider=cfg.get("provider"), + litellm_provider=cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), @@ -123,7 +122,7 @@ def load_global_llm_configs(): # Stamp Auto (Fastest) ranking metadata. YAML configs are always # Tier A — operator-curated, locked first when premium-eligible. # The OpenRouter refresh tick later re-stamps health for any cfg - # whose provider == "OPENROUTER" via _enrich_health. + # whose litellm_provider == "openrouter" via _enrich_health. try: from app.services.quality_score import static_score_yaml @@ -133,7 +132,7 @@ def load_global_llm_configs(): cfg["quality_score_static"] = static_q cfg["quality_score"] = static_q cfg["quality_score_health"] = None - # YAML cfgs whose provider is OPENROUTER are also subject + # YAML cfgs whose litellm_provider is openrouter are also subject # to health gating against their own /endpoints data — a # hand-picked dead OR model is still dead. _enrich_health # re-stamps health_gated for them on the next refresh tick. diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index b0eee6458..06676511f 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -18,7 +18,7 @@ # - Configure router_settings below to customize the load balancing behavior # # Static config shape: -# - Connection fields: provider, api_key, api_base, api_version +# - Connection fields: litellm_provider, api_key, api_base, api_version # - Model fields: model_name, billing_tier, rpm/tpm, litellm_params # - Prompt defaults: system_instructions, citations_enabled # IDs share one GLOBAL model namespace across chat, vision, and image generation. @@ -75,10 +75,10 @@ global_llm_configs: seo_enabled: true seo_slug: "gpt-4-turbo" quota_reserve_tokens: 4000 - provider: "OPENAI" + litellm_provider: "openai" model_name: "gpt-4-turbo-preview" api_key: "sk-your-openai-api-key-here" - api_base: "" + api_base: "https://api.openai.com/v1" # Rate limits for load balancing (requests/tokens per minute) rpm: 500 # Requests per minute tpm: 100000 # Tokens per minute @@ -99,10 +99,10 @@ global_llm_configs: seo_enabled: true seo_slug: "claude-3-opus" quota_reserve_tokens: 4000 - provider: "ANTHROPIC" + litellm_provider: "anthropic" model_name: "claude-3-opus-20240229" api_key: "sk-ant-your-anthropic-api-key-here" - api_base: "" + api_base: "https://api.anthropic.com/v1" rpm: 1000 tpm: 100000 litellm_params: @@ -121,10 +121,10 @@ global_llm_configs: seo_enabled: true seo_slug: "gpt-3.5-turbo-fast" quota_reserve_tokens: 2000 - provider: "OPENAI" + litellm_provider: "openai" model_name: "gpt-3.5-turbo" api_key: "sk-your-openai-api-key-here" - api_base: "" + api_base: "https://api.openai.com/v1" rpm: 3500 # GPT-3.5 has higher rate limits tpm: 200000 litellm_params: @@ -143,7 +143,7 @@ global_llm_configs: seo_enabled: true seo_slug: "deepseek-chat-chinese" quota_reserve_tokens: 4000 - provider: "DEEPSEEK" + litellm_provider: "openai" model_name: "deepseek-chat" api_key: "your-deepseek-api-key-here" api_base: "https://api.deepseek.com/v1" @@ -175,7 +175,7 @@ global_llm_configs: seo_enabled: true seo_slug: "azure-gpt-4o" quota_reserve_tokens: 4000 - provider: "AZURE" + litellm_provider: "azure" # model_name format for Azure: azure/ model_name: "azure/gpt-4o-deployment" api_key: "your-azure-api-key-here" @@ -203,7 +203,7 @@ global_llm_configs: seo_enabled: true seo_slug: "azure-gpt-4-turbo" quota_reserve_tokens: 4000 - provider: "AZURE" + litellm_provider: "azure" model_name: "azure/gpt-4-turbo-deployment" api_key: "your-azure-api-key-here" api_base: "https://your-resource.openai.azure.com" @@ -227,10 +227,10 @@ global_llm_configs: seo_enabled: true seo_slug: "groq-llama-3" quota_reserve_tokens: 8000 - provider: "GROQ" + litellm_provider: "groq" model_name: "llama3-70b-8192" api_key: "your-groq-api-key-here" - api_base: "" + api_base: "https://api.groq.com/openai/v1" rpm: 30 # Groq has lower rate limits on free tier tpm: 14400 litellm_params: @@ -249,7 +249,7 @@ global_llm_configs: seo_enabled: true seo_slug: "minimax-m3" quota_reserve_tokens: 4000 - provider: "MINIMAX" + litellm_provider: "openai" model_name: "MiniMax-M3" api_key: "your-minimax-api-key-here" api_base: "https://api.minimax.io/v1" @@ -288,10 +288,10 @@ global_llm_configs: anonymous_enabled: false seo_enabled: false quota_reserve_tokens: 1000 - provider: "OPENAI" + litellm_provider: "openai" model_name: "gpt-4o-mini" api_key: "sk-your-openai-api-key-here" - api_base: "" + api_base: "https://api.openai.com/v1" rpm: 3500 tpm: 200000 litellm_params: @@ -391,10 +391,10 @@ global_image_generation_configs: - id: -2001 name: "Global DALL-E 3" description: "OpenAI's DALL-E 3 for high-quality image generation" - provider: "OPENAI" + litellm_provider: "openai" model_name: "dall-e-3" api_key: "sk-your-openai-api-key-here" - api_base: "" + api_base: "https://api.openai.com/v1" rpm: 50 # Requests per minute (image gen is rate-limited by RPM, not tokens) litellm_params: {} @@ -402,10 +402,10 @@ global_image_generation_configs: - id: -2002 name: "Global GPT Image 1" description: "OpenAI's GPT Image 1 model" - provider: "OPENAI" + litellm_provider: "openai" model_name: "gpt-image-1" api_key: "sk-your-openai-api-key-here" - api_base: "" + api_base: "https://api.openai.com/v1" rpm: 50 litellm_params: {} @@ -413,7 +413,7 @@ global_image_generation_configs: - id: -2003 name: "Global Azure DALL-E 3" description: "Azure-hosted DALL-E 3 deployment" - provider: "AZURE_OPENAI" + litellm_provider: "azure" model_name: "azure/dall-e-3-deployment" api_key: "your-azure-api-key-here" api_base: "https://your-resource.openai.azure.com" @@ -426,10 +426,10 @@ global_image_generation_configs: # - id: -2004 # name: "Global Gemini Image Gen" # description: "Google Gemini image generation via OpenRouter" - # provider: "OPENROUTER" + # litellm_provider: "openrouter" # model_name: "google/gemini-2.5-flash-image" # api_key: "your-openrouter-api-key-here" - # api_base: "" + # api_base: "https://openrouter.ai/api/v1" # rpm: 30 # litellm_params: {} @@ -455,10 +455,10 @@ global_vision_llm_configs: - id: -1001 name: "Global GPT-4o Vision" description: "OpenAI's GPT-4o with strong vision capabilities" - provider: "OPENAI" + litellm_provider: "openai" model_name: "gpt-4o" api_key: "sk-your-openai-api-key-here" - api_base: "" + api_base: "https://api.openai.com/v1" rpm: 500 tpm: 100000 litellm_params: @@ -469,10 +469,10 @@ global_vision_llm_configs: - id: -1002 name: "Global Gemini 2.0 Flash" description: "Google's fast vision model with large context" - provider: "GOOGLE" + litellm_provider: "gemini" model_name: "gemini-2.0-flash" api_key: "your-google-ai-api-key-here" - api_base: "" + api_base: "https://generativelanguage.googleapis.com/v1beta" rpm: 1000 tpm: 200000 litellm_params: @@ -483,10 +483,10 @@ global_vision_llm_configs: - id: -1003 name: "Global Claude 3.5 Sonnet Vision" description: "Anthropic's Claude 3.5 Sonnet with vision support" - provider: "ANTHROPIC" + litellm_provider: "anthropic" model_name: "claude-3-5-sonnet-20241022" api_key: "sk-ant-your-anthropic-api-key-here" - api_base: "" + api_base: "https://api.anthropic.com/v1" rpm: 1000 tpm: 100000 litellm_params: @@ -497,7 +497,7 @@ global_vision_llm_configs: # - id: -1004 # name: "Global Azure GPT-4o Vision" # description: "Azure-hosted GPT-4o for vision analysis" - # provider: "AZURE_OPENAI" + # litellm_provider: "azure" # model_name: "azure/gpt-4o-deployment" # api_key: "your-azure-api-key-here" # api_base: "https://your-resource.openai.azure.com" @@ -518,7 +518,7 @@ global_vision_llm_configs: # - system_instructions: Custom prompt or empty string to use defaults # - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty # - citations_enabled: true = include citation instructions, false = include anti-citation instructions -# - All standard LiteLLM providers are supported +# - All standard LiteLLM provider adapter names are supported # - rpm/tpm: Optional rate limits for load balancing (requests/tokens per minute) # These help the router distribute load evenly and avoid rate limit errors # diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 9756cb32f..4c628b05a 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -283,7 +283,7 @@ class VisionProvider(StrEnum): class ConnectionProtocol(StrEnum): OLLAMA = "OLLAMA" OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" - NATIVE = "NATIVE" + ANTHROPIC = "ANTHROPIC" class ConnectionScope(StrEnum): @@ -1663,7 +1663,7 @@ class Connection(BaseModel, TimestampMixin): __tablename__ = "connections" protocol = Column(SQLAlchemyEnum(ConnectionProtocol), nullable=False, index=True) - native_provider = Column(String(100), nullable=True, index=True) + litellm_provider = Column(String(100), nullable=True, index=True) base_url = Column(String(500), nullable=True) api_key = Column(String, nullable=True) extra = Column(JSONB, nullable=False, default=dict, server_default="{}") diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 5872671b1..cae951c3a 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -8,6 +8,7 @@ from sqlalchemy.orm import selectinload from app.config import config from app.db import ( Connection, + ConnectionProtocol, ConnectionScope, Model, ModelSource, @@ -40,6 +41,16 @@ router = APIRouter() logger = logging.getLogger(__name__) +def _default_litellm_provider(protocol: ConnectionProtocol | str) -> str: + protocol_value = getattr(protocol, "value", str(protocol)) + defaults = { + ConnectionProtocol.OLLAMA.value: "ollama_chat", + ConnectionProtocol.ANTHROPIC.value: "anthropic", + ConnectionProtocol.OPENAI_COMPATIBLE.value: "openai", + } + return defaults.get(protocol_value, "openai") + + def _model_read(model: Model | dict) -> ModelRead: return ModelRead.model_validate(model) @@ -58,7 +69,7 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None return ConnectionRead( id=conn.id, protocol=conn.protocol, - native_provider=conn.native_provider, + litellm_provider=conn.litellm_provider, base_url=conn.base_url, extra=conn.extra or {}, scope=conn.scope, @@ -168,8 +179,12 @@ async def create_connection( Permission.LLM_CONFIGS_CREATE.value, "You don't have permission to create model connections in this search space", ) + payload = data.model_dump(exclude={"search_space_id"}) + if not payload.get("litellm_provider"): + payload["litellm_provider"] = _default_litellm_provider(data.protocol) + conn = Connection( - **data.model_dump(exclude={"search_space_id"}), + **payload, search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None, user_id=user.id, ) diff --git a/surfsense_backend/app/routes/new_llm_config_routes.py b/surfsense_backend/app/routes/new_llm_config_routes.py index 84d66bb13..531fa8730 100644 --- a/surfsense_backend/app/routes/new_llm_config_routes.py +++ b/surfsense_backend/app/routes/new_llm_config_routes.py @@ -57,7 +57,7 @@ def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead: litellm_params.get("base_model") if isinstance(litellm_params, dict) else None ) supports_image_input = derive_supports_image_input( - provider=provider_value, + litellm_provider=provider_value.lower(), model_name=config.model_name, base_model=base_model, custom_provider=config.custom_provider, @@ -147,7 +147,7 @@ async def get_global_new_llm_configs( else None ) supports_image_input = derive_supports_image_input( - provider=cfg.get("provider"), + litellm_provider=cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=cfg_base_model, custom_provider=cfg.get("custom_provider"), @@ -157,7 +157,7 @@ async def get_global_new_llm_configs( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("provider"), + "provider": cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 898077b7a..2cda04221 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -419,7 +419,7 @@ async def _get_llm_config_by_id( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("provider"), + "provider": cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base"), @@ -490,7 +490,7 @@ async def _get_image_gen_config_by_id( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("provider"), + "provider": cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, @@ -550,7 +550,7 @@ async def _get_vision_llm_config_by_id( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("provider"), + "provider": cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py index ea1ec4e88..306dd63c8 100644 --- a/surfsense_backend/app/schemas/model_connections.py +++ b/surfsense_backend/app/schemas/model_connections.py @@ -29,7 +29,7 @@ class ModelRead(BaseModel): class ConnectionRead(BaseModel): id: int protocol: ConnectionProtocol | str - native_provider: str | None = None + litellm_provider: str | None = None base_url: str | None = None extra: dict[str, Any] = Field(default_factory=dict) scope: ConnectionScope | str @@ -48,7 +48,7 @@ class ConnectionRead(BaseModel): class ConnectionCreate(BaseModel): protocol: ConnectionProtocol - native_provider: str | None = None + litellm_provider: str | None = Field(None, max_length=100) base_url: str | None = Field(None, max_length=500) api_key: str | None = None extra: dict[str, Any] = Field(default_factory=dict) @@ -58,7 +58,7 @@ class ConnectionCreate(BaseModel): class ConnectionUpdate(BaseModel): - native_provider: str | None = None + litellm_provider: str | None = Field(None, max_length=100) base_url: str | None = Field(None, max_length=500) api_key: str | None = None extra: dict[str, Any] | None = None diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index 42a4792a4..5e5b231f9 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import contextlib import logging from dataclasses import dataclass @@ -13,8 +12,7 @@ import httpx import litellm from app.db import Connection, ConnectionProtocol, Model, ModelSource -from app.services.model_resolver import NATIVE_PROVIDER_PREFIX, ensure_v1, to_litellm -from app.services.provider_api_base import resolve_api_base +from app.services.model_resolver import ensure_v1, to_litellm logger = logging.getLogger(__name__) @@ -36,6 +34,13 @@ def _auth_headers(conn: Connection) -> dict[str, str]: return {"Authorization": f"Bearer {conn.api_key}"} +def _anthropic_headers(conn: Connection) -> dict[str, str]: + headers = {"anthropic-version": "2023-06-01"} + if conn.api_key: + headers["x-api-key"] = conn.api_key + return headers + + def _docker_hint(url: str | None, exc_or_status: Any) -> str: raw = str(exc_or_status) if not url: @@ -56,24 +61,26 @@ def _docker_hint(url: str | None, exc_or_status: Any) -> str: async def verify_connection(conn: Connection) -> VerifyResult: - if not conn.base_url and conn.protocol in ( - ConnectionProtocol.OLLAMA, - ConnectionProtocol.OPENAI_COMPATIBLE, - ): + if not conn.base_url: return VerifyResult("UNREACHABLE", False, "Base URL is required.") if conn.protocol == ConnectionProtocol.OLLAMA: url = f"{conn.base_url.rstrip('/')}/api/version" elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: url = f"{ensure_v1(conn.base_url)}/models" + elif conn.protocol == ConnectionProtocol.ANTHROPIC: + url = f"{conn.base_url.rstrip('/')}/models" else: - # Native providers do not share one cheap health endpoint. The model - # probe exercises the real path and is the authoritative check. - return VerifyResult("OK", True, "Native provider configuration accepted.") + return VerifyResult("UNREACHABLE", False, "Unsupported connection protocol.") try: async with httpx.AsyncClient(timeout=VERIFY_TIMEOUT_SECONDS) as client: - response = await client.get(url, headers=_auth_headers(conn)) + headers = ( + _anthropic_headers(conn) + if conn.protocol == ConnectionProtocol.ANTHROPIC + else _auth_headers(conn) + ) + response = await client.get(url, headers=headers) if response.status_code in (401, 403): return VerifyResult("AUTH_FAILED", False, "Authentication failed.") if response.status_code == 404: @@ -156,39 +163,25 @@ async def _discover_openai_shaped_models(conn: Connection, base_url: str | None) ] -def _litellm_valid_model_ids(provider: str, api_key: str | None) -> list[str]: - if not api_key: +async def _discover_anthropic_models(conn: Connection) -> list[dict[str, Any]]: + if not conn.base_url: return [] - try: - models = litellm.get_valid_models( - check_provider_endpoint=True, - custom_llm_provider=provider, - api_key=api_key, - ) - except Exception as exc: - logger.warning("LiteLLM model discovery failed for provider %s: %s", provider, exc) - return [] - - provider_prefix = f"{provider}/" - return [ - model.removeprefix(provider_prefix) - for model in models - if isinstance(model, str) and model.strip() - ] - - -async def _discover_litellm_native_models(conn: Connection, provider: str) -> list[dict[str, Any]]: - model_ids = await asyncio.to_thread(_litellm_valid_model_ids, provider, conn.api_key) + url = f"{conn.base_url.rstrip('/')}/models" + async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: + response = await client.get(url, headers=_anthropic_headers(conn)) + response.raise_for_status() + models = response.json().get("data", []) return [ { - "model_id": model_id, - "display_name": model_id, + "model_id": item.get("id"), + "display_name": item.get("display_name") or item.get("id"), "source": ModelSource.DISCOVERED, - "capabilities": derive_capabilities(conn, model_id), - "metadata": {}, + "capabilities": derive_capabilities(conn, item.get("id"), item), + "metadata": item, } - for model_id in model_ids + for item in models + if item.get("id") ] @@ -231,20 +224,10 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]: ] elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: results = await _discover_openai_shaped_models(conn, conn.base_url) + elif conn.protocol == ConnectionProtocol.ANTHROPIC: + results = await _discover_anthropic_models(conn) else: - provider_key = (conn.native_provider or "").upper() - provider = NATIVE_PROVIDER_PREFIX.get(provider_key, provider_key.lower()) - api_base = resolve_api_base( - provider=provider_key, - provider_prefix=provider, - config_api_base=conn.base_url, - ) - if api_base: - results = await _discover_openai_shaped_models(conn, api_base) - elif provider: - results = await _discover_litellm_native_models(conn, provider) - else: - results = [] + results = [] if allowlist: results = [item for item in results if item["model_id"] in allowlist] diff --git a/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml b/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml index 017fa1eb3..9ea5e1a29 100644 --- a/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml +++ b/surfsense_backend/tests/e2e/fixtures/global_llm_config.yaml @@ -19,7 +19,7 @@ # so the resolved auto-pin id is never sent to a real LLM provider. # The values below only need to pass # auto_model_pin_service._is_usable_global_config() -# which requires id / model_name / provider / api_key all truthy. +# which requires id / model_name / litellm_provider / api_key all truthy. # # Why TWO entries (premium + free): # auto_model_pin_service.resolve_or_get_pinned_llm_config_id() splits @@ -44,9 +44,10 @@ global_llm_configs: anonymous_enabled: false seo_enabled: false quality_score: 1.0 - provider: "OPENAI" + litellm_provider: "openai" model_name: "fake-e2e-model-premium" api_key: "fake-e2e-api-key-not-for-production" + api_base: "https://api.openai.com/v1" supports_image_input: false quota_reserve_tokens: 1024 rpm: 1000 @@ -60,9 +61,10 @@ global_llm_configs: anonymous_enabled: false seo_enabled: false quality_score: 1.0 - provider: "OPENAI" + litellm_provider: "openai" model_name: "fake-e2e-model-free" api_key: "fake-e2e-api-key-not-for-production" + api_base: "https://api.openai.com/v1" supports_image_input: false quota_reserve_tokens: 1024 rpm: 1000 diff --git a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py index 2b6c76485..fff61f14b 100644 --- a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py +++ b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py @@ -25,7 +25,7 @@ _IMAGE_FIXTURE: list[dict] = [ { "id": -1, "name": "DALL-E 3", - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "dall-e-3", "api_key": "sk-test", "billing_tier": "free", @@ -33,7 +33,7 @@ _IMAGE_FIXTURE: list[dict] = [ { "id": -2, "name": "GPT-Image 1 (premium)", - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-image-1", "api_key": "sk-test", "billing_tier": "premium", @@ -41,7 +41,7 @@ _IMAGE_FIXTURE: list[dict] = [ { "id": -20_001, "name": "google/gemini-2.5-flash-image (OpenRouter)", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemini-2.5-flash-image", "api_key": "sk-or-test", "api_base": "https://openrouter.ai/api/v1", @@ -54,7 +54,7 @@ _VISION_FIXTURE: list[dict] = [ { "id": -1, "name": "GPT-4o Vision", - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-4o", "api_key": "sk-test", "billing_tier": "free", @@ -62,7 +62,7 @@ _VISION_FIXTURE: list[dict] = [ { "id": -2, "name": "Claude 3.5 Sonnet (premium)", - "provider": "ANTHROPIC", + "litellm_provider": "anthropic", "model_name": "claude-3-5-sonnet", "api_key": "sk-ant-test", "billing_tier": "premium", @@ -70,7 +70,7 @@ _VISION_FIXTURE: list[dict] = [ { "id": -30_001, "name": "openai/gpt-4o (OpenRouter)", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "openai/gpt-4o", "api_key": "sk-or-test", "api_base": "https://openrouter.ai/api/v1", diff --git a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py index b47d9134b..67d1112f3 100644 --- a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py +++ b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py @@ -26,7 +26,7 @@ _FIXTURE: list[dict] = [ "id": -1, "name": "GPT-4o (explicit true)", "description": "vision-capable, explicit YAML override", - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-4o", "api_key": "sk-test", "billing_tier": "free", @@ -36,7 +36,7 @@ _FIXTURE: list[dict] = [ "id": -2, "name": "DeepSeek V3 (explicit false)", "description": "OpenRouter dynamic — modality-derived false", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "deepseek/deepseek-v3.2-exp", "api_key": "sk-or-test", "api_base": "https://openrouter.ai/api/v1", @@ -47,7 +47,7 @@ _FIXTURE: list[dict] = [ "id": -10_010, "name": "Unannotated GPT-4o", "description": "no flag set — resolver should derive True via LiteLLM", - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-4o", "api_key": "sk-test", "billing_tier": "free", @@ -57,7 +57,7 @@ _FIXTURE: list[dict] = [ "id": -10_011, "name": "Unannotated unknown model", "description": "unmapped — default-allow True", - "provider": "CUSTOM", + "litellm_provider": "custom", "custom_provider": "brand_new_proxy", "model_name": "brand-new-model-x9", "api_key": "sk-test", diff --git a/surfsense_backend/tests/unit/services/test_model_connections.py b/surfsense_backend/tests/unit/services/test_model_connections.py index 98042501b..797f794b1 100644 --- a/surfsense_backend/tests/unit/services/test_model_connections.py +++ b/surfsense_backend/tests/unit/services/test_model_connections.py @@ -2,11 +2,12 @@ from app.services.global_model_catalog import materialize_global_model_catalog from app.services.model_resolver import ensure_v1, to_litellm -def test_openai_compatible_resolver_normalizes_v1() -> None: +def test_openai_compatible_resolver_uses_explicit_api_base() -> None: model, kwargs = to_litellm( { "protocol": "OPENAI_COMPATIBLE", - "base_url": "http://host.docker.internal:1234", + "litellm_provider": "openai", + "base_url": "http://host.docker.internal:1234/v1", "api_key": "local-key", "extra": {}, }, @@ -23,6 +24,7 @@ def test_ollama_resolver_uses_native_api_base() -> None: model, kwargs = to_litellm( { "protocol": "OLLAMA", + "litellm_provider": "ollama_chat", "base_url": "http://host.docker.internal:11434", "api_key": None, "extra": {}, @@ -40,9 +42,10 @@ def test_global_materialization_preserves_tier_and_keeps_key_server_side() -> No { "id": -101, "name": "OpenRouter Free", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "meta-llama/llama-3.1-8b-instruct:free", "api_key": "sk-global-secret", + "api_base": "https://openrouter.ai/api/v1", "billing_tier": "free", "anonymous_enabled": True, "seo_enabled": True, @@ -52,9 +55,10 @@ def test_global_materialization_preserves_tier_and_keeps_key_server_side() -> No { "id": -102, "name": "OpenRouter Premium", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "anthropic/claude-sonnet-4", "api_key": "sk-global-secret", + "api_base": "https://openrouter.ai/api/v1", "billing_tier": "premium", }, ], From 8f20a3257189807136ddf56fa942683bf4b366ce Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 18:21:07 +0530 Subject: [PATCH 18/59] refactor(model-connections): consolidate provider capability handling --- .../app/services/global_model_catalog.py | 2 +- .../openrouter_integration_service.py | 21 +--- .../app/services/pricing_registration.py | 7 +- .../app/services/provider_api_base.py | 106 ----------------- .../app/services/provider_capabilities.py | 31 +---- .../app/services/quality_score.py | 38 +++---- .../test_openrouter_integration_service.py | 12 +- .../services/test_pricing_registration.py | 17 ++- .../unit/services/test_provider_api_base.py | 107 ------------------ .../services/test_provider_capabilities.py | 28 ++--- .../tests/unit/services/test_quality_score.py | 6 +- 11 files changed, 64 insertions(+), 311 deletions(-) delete mode 100644 surfsense_backend/app/services/provider_api_base.py delete mode 100644 surfsense_backend/tests/unit/services/test_provider_api_base.py diff --git a/surfsense_backend/app/services/global_model_catalog.py b/surfsense_backend/app/services/global_model_catalog.py index a43f58b9e..e40b3a942 100644 --- a/surfsense_backend/app/services/global_model_catalog.py +++ b/surfsense_backend/app/services/global_model_catalog.py @@ -19,7 +19,7 @@ def _connection_key(conn: dict[str, Any]) -> tuple[Any, ...]: # the same provider/base can have different quota/rate limits upstream. return ( conn.get("protocol"), - conn.get("native_provider"), + conn.get("litellm_provider"), conn.get("base_url"), conn.get("api_key"), _freeze(conn.get("extra") or {}), diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 6454e2d58..6996f0fde 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -323,10 +323,10 @@ def _generate_configs( "seo_enabled": seo_enabled, "seo_slug": None, "quota_reserve_tokens": quota_reserve_tokens, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": model_id, "api_key": api_key, - "api_base": "", + "api_base": "https://openrouter.ai/api/v1", "rpm": free_rpm if tier == "free" else rpm, "tpm": free_tpm if tier == "free" else tpm, "litellm_params": dict(litellm_params), @@ -420,14 +420,9 @@ def _generate_image_gen_configs( "id": _stable_config_id(model_id, id_offset, taken), "name": name, "description": f"{name} via OpenRouter (image generation)", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": model_id, "api_key": api_key, - # Pin to OpenRouter's public base URL so a downstream call site - # that forgets ``resolve_api_base`` still doesn't inherit - # ``AZURE_OPENAI_ENDPOINT`` and 404 on - # ``image_generation/transformation`` (defense-in-depth, see - # ``provider_api_base`` docstring). "api_base": "https://openrouter.ai/api/v1", "api_version": None, "rpm": free_rpm if tier == "free" else rpm, @@ -504,13 +499,9 @@ def _generate_vision_llm_configs( "id": _stable_config_id(model_id, id_offset, taken), "name": name, "description": f"{name} via OpenRouter (vision)", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": model_id, "api_key": api_key, - # Pin to OpenRouter's public base URL so a downstream call site - # that forgets ``resolve_api_base`` still doesn't inherit - # ``AZURE_OPENAI_ENDPOINT`` (defense-in-depth, see - # ``provider_api_base`` docstring). "api_base": "https://openrouter.ai/api/v1", "api_version": None, "rpm": free_rpm if tier == "free" else rpm, @@ -710,7 +701,7 @@ class OpenRouterIntegrationService: ) # Re-blend health scores against the freshly fetched catalogue. Also - # re-stamps health for any YAML-curated cfg with provider==OPENROUTER + # re-stamps health for any YAML-curated cfg with litellm_provider=openrouter # so a hand-picked dead OR model is gated like a dynamic one. await self._enrich_health_safely(static_configs + new_configs, log_summary=True) @@ -787,7 +778,7 @@ class OpenRouterIntegrationService: the entire previous cycle's cache for this run. """ or_cfgs = [ - c for c in configs if str(c.get("provider", "")).upper() == "OPENROUTER" + c for c in configs if str(c.get("litellm_provider", "")).lower() == "openrouter" ] if not or_cfgs: return diff --git a/surfsense_backend/app/services/pricing_registration.py b/surfsense_backend/app/services/pricing_registration.py index de98e50c2..6b99fe723 100644 --- a/surfsense_backend/app/services/pricing_registration.py +++ b/surfsense_backend/app/services/pricing_registration.py @@ -143,12 +143,12 @@ def _register_chat_shape_configs( sample_keys: list[str] = [] for cfg in configs: - provider = str(cfg.get("provider") or "").upper() + provider = str(cfg.get("litellm_provider") or "").lower() model_name = str(cfg.get("model_name") or "").strip() litellm_params = cfg.get("litellm_params") or {} base_model = str(litellm_params.get("base_model") or model_name).strip() - if provider == "OPENROUTER": + if provider == "openrouter": entry = or_pricing.get(model_name) if entry: input_cost = _safe_float(entry.get("prompt")) @@ -189,12 +189,11 @@ def _register_chat_shape_configs( skipped_no_pricing += 1 continue aliases = _alias_set_for_yaml(provider, model_name, base_model) - provider_slug = "azure" if provider == "AZURE_OPENAI" else provider.lower() count = _register( aliases, input_cost=input_cost, output_cost=output_cost, - provider=provider_slug, + provider=provider, ) if count > 0: registered_models += 1 diff --git a/surfsense_backend/app/services/provider_api_base.py b/surfsense_backend/app/services/provider_api_base.py deleted file mode 100644 index dca1f9462..000000000 --- a/surfsense_backend/app/services/provider_api_base.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Provider-aware ``api_base`` resolution shared by chat / image-gen / vision. - -LiteLLM falls back to the module-global ``litellm.api_base`` when an -individual call doesn't pass one, which silently inherits provider-agnostic -env vars like ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE``. Without an -explicit ``api_base``, an ``openrouter/`` request can end up at an -Azure endpoint and 404 with ``Resource not found`` (real reproducer: -[litellm/llms/openrouter/image_generation/transformation.py:242-263] appends -``/chat/completions`` to whatever inherited base it gets, regardless of -provider). - -The chat router has had this defense for a while -(``llm_router_service.py:466-478``). This module hoists the maps + cascade -into a tiny standalone helper so vision and image-gen can share the same -source of truth without an inter-service circular import. -""" - -from __future__ import annotations - -PROVIDER_DEFAULT_API_BASE: dict[str, str] = { - "openrouter": "https://openrouter.ai/api/v1", - "groq": "https://api.groq.com/openai/v1", - "mistral": "https://api.mistral.ai/v1", - "perplexity": "https://api.perplexity.ai", - "xai": "https://api.x.ai/v1", - "cerebras": "https://api.cerebras.ai/v1", - "deepinfra": "https://api.deepinfra.com/v1/openai", - "fireworks_ai": "https://api.fireworks.ai/inference/v1", - "together_ai": "https://api.together.xyz/v1", - "anyscale": "https://api.endpoints.anyscale.com/v1", - "cometapi": "https://api.cometapi.com/v1", - "sambanova": "https://api.sambanova.ai/v1", -} -"""Default ``api_base`` per LiteLLM provider prefix (lowercase). - -Only providers with a well-known, stable public base URL are listed — -self-hosted / BYO-endpoint providers (ollama, custom, bedrock, vertex_ai, -huggingface, databricks, cloudflare, replicate) are intentionally omitted -so their existing config-driven behaviour is preserved.""" - - -PROVIDER_KEY_DEFAULT_API_BASE: dict[str, str] = { - "DEEPSEEK": "https://api.deepseek.com/v1", - "ALIBABA_QWEN": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1", - "MOONSHOT": "https://api.moonshot.ai/v1", - "ZHIPU": "https://open.bigmodel.cn/api/paas/v4", - "MINIMAX": "https://api.minimax.io/v1", -} -"""Canonical provider key (uppercase) → base URL. - -Used when the LiteLLM provider prefix is the generic ``openai`` shim but the -config's ``provider`` field tells us which API it actually is (DeepSeek, -Alibaba, Moonshot, Zhipu, MiniMax all use the ``openai`` prefix but each -has its own base URL).""" - - -def resolve_api_base( - *, - provider: str | None, - provider_prefix: str | None, - config_api_base: str | None, -) -> str | None: - """Resolve a non-Azure-leaking ``api_base`` for a deployment. - - Cascade (first non-empty wins): - 1. The config's own ``api_base`` (whitespace-only treated as missing). - 2. ``PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]``. - 3. ``PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]``. - 4. ``None`` — caller should NOT set ``api_base`` and let the LiteLLM - provider integration apply its own default (e.g. AzureOpenAI's - deployment-derived URL, custom provider's per-deployment URL). - - Args: - provider: The config's ``provider`` field (e.g. ``"OPENROUTER"``, - ``"DEEPSEEK"``). Case-insensitive. - provider_prefix: The LiteLLM model-string prefix the same call - site builds for the model id (e.g. ``"openrouter"``, - ``"groq"``). Case-insensitive. - config_api_base: ``api_base`` from the global YAML / DB row / - OpenRouter dynamic config. Empty / whitespace-only means - "missing" — the resolver still applies the cascade. - - Returns: - A URL string, or ``None`` if no default applies for this provider. - """ - if config_api_base and config_api_base.strip(): - return config_api_base - - if provider: - key_default = PROVIDER_KEY_DEFAULT_API_BASE.get(provider.upper()) - if key_default: - return key_default - - if provider_prefix: - prefix_default = PROVIDER_DEFAULT_API_BASE.get(provider_prefix.lower()) - if prefix_default: - return prefix_default - - return None - - -__all__ = [ - "PROVIDER_DEFAULT_API_BASE", - "PROVIDER_KEY_DEFAULT_API_BASE", - "resolve_api_base", -] diff --git a/surfsense_backend/app/services/provider_capabilities.py b/surfsense_backend/app/services/provider_capabilities.py index 9e1433214..9521ef7a4 100644 --- a/surfsense_backend/app/services/provider_capabilities.py +++ b/surfsense_backend/app/services/provider_capabilities.py @@ -46,26 +46,12 @@ from collections.abc import Iterable import litellm -from app.services.model_resolver import NATIVE_PROVIDER_PREFIX - logger = logging.getLogger(__name__) -# Provider-name → LiteLLM model-prefix map. -# -# Owned here because ``app.services.provider_capabilities`` is the -# only edge that's safe to call from ``app.config``'s YAML loader at -# class-body init time. ``app.agents.chat.runtime.llm_config`` re-exports -# this constant under the historical ``PROVIDER_MAP`` name; placing the -# map there directly would re-introduce the -# ``app.config -> ... -> deliverables/tools/generate_image -> -# app.config`` cycle that prompted the move. -_PROVIDER_PREFIX_MAP = NATIVE_PROVIDER_PREFIX - - def _candidate_model_strings( *, - provider: str | None, + litellm_provider: str | None, model_name: str | None, base_model: str | None, custom_provider: str | None, @@ -92,12 +78,7 @@ def _candidate_model_strings( seen.add(key) candidates.append(key) - provider_prefix: str | None = None - if provider: - provider_prefix = _PROVIDER_PREFIX_MAP.get(provider.upper(), provider.lower()) - if custom_provider: - # ``custom_provider`` overrides everything for CUSTOM/proxy setups. - provider_prefix = custom_provider + provider_prefix = custom_provider or litellm_provider primary_model = base_model or model_name bare_model = model_name @@ -132,7 +113,7 @@ def _candidate_model_strings( def derive_supports_image_input( *, - provider: str | None = None, + litellm_provider: str | None = None, model_name: str | None = None, base_model: str | None = None, custom_provider: str | None = None, @@ -166,7 +147,7 @@ def derive_supports_image_input( return False for model_string, custom_llm_provider in _candidate_model_strings( - provider=provider, + litellm_provider=litellm_provider, model_name=model_name, base_model=base_model, custom_provider=custom_provider, @@ -191,7 +172,7 @@ def derive_supports_image_input( def is_known_text_only_chat_model( *, - provider: str | None = None, + litellm_provider: str | None = None, model_name: str | None = None, base_model: str | None = None, custom_provider: str | None = None, @@ -212,7 +193,7 @@ def is_known_text_only_chat_model( leads to the regression we're fixing here. """ for model_string, custom_llm_provider in _candidate_model_strings( - provider=provider, + litellm_provider=litellm_provider, model_name=model_name, base_model=base_model, custom_provider=custom_provider, diff --git a/surfsense_backend/app/services/quality_score.py b/surfsense_backend/app/services/quality_score.py index 2fb37de21..95484439b 100644 --- a/surfsense_backend/app/services/quality_score.py +++ b/surfsense_backend/app/services/quality_score.py @@ -108,25 +108,23 @@ PROVIDER_PRESTIGE_OR: dict[str, int] = { # YAML provider field (the upstream API shape the operator selected). PROVIDER_PRESTIGE_YAML: dict[str, int] = { - "AZURE_OPENAI": 50, - "OPENAI": 50, - "ANTHROPIC": 50, - "GOOGLE": 50, - "VERTEX_AI": 50, - "GEMINI": 50, - "XAI": 50, - "MISTRAL": 38, - "DEEPSEEK": 38, - "COHERE": 38, - "GROQ": 30, - "TOGETHER_AI": 28, - "FIREWORKS_AI": 28, - "PERPLEXITY": 28, - "MINIMAX": 28, - "BEDROCK": 28, - "OPENROUTER": 25, - "OLLAMA": 12, - "CUSTOM": 12, + "azure": 50, + "openai": 50, + "anthropic": 50, + "gemini": 50, + "vertex_ai": 50, + "xai": 50, + "mistral": 38, + "deepseek": 38, + "cohere": 38, + "groq": 30, + "together_ai": 28, + "fireworks_ai": 28, + "perplexity": 28, + "bedrock": 28, + "openrouter": 25, + "ollama_chat": 12, + "custom": 12, } @@ -275,7 +273,7 @@ def static_score_yaml(cfg: dict) -> int: listed this model. Pricing / context fall through to lazy ``litellm`` lookups; failures are silent (we just lose those sub-points). """ - provider = str(cfg.get("provider", "")).upper() + provider = str(cfg.get("litellm_provider", "")).lower() base = PROVIDER_PRESTIGE_YAML.get(provider, 15) model_name = cfg.get("model_name") or "" diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py index 88fcf2db3..9d4c1a04b 100644 --- a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -263,11 +263,10 @@ def test_generate_image_gen_configs_filters_by_image_output(): # Each config must carry ``billing_tier`` for routing in image_generation_routes. for c in cfgs: assert c["billing_tier"] in {"free", "premium"} - assert c["provider"] == "OPENROUTER" + assert c["litellm_provider"] == "openrouter" assert c[_OPENROUTER_DYNAMIC_MARKER] is True - # Defense-in-depth: emit the OpenRouter base URL at source so a - # downstream call site that forgets ``resolve_api_base`` still - # doesn't 404 against an inherited Azure endpoint. + # Emit the OpenRouter base URL at source so every call path passes an + # explicit api_base and cannot inherit a process-global endpoint. assert c["api_base"] == "https://openrouter.ai/api/v1" @@ -346,9 +345,8 @@ def test_generate_vision_llm_configs_filters_by_image_input_text_output(): assert cfg["input_cost_per_token"] == pytest.approx(5e-6) assert cfg["output_cost_per_token"] == pytest.approx(15e-6) assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True - # Defense-in-depth: emit the OpenRouter base URL at source so a - # downstream call site that forgets ``resolve_api_base`` still - # doesn't inherit an Azure endpoint. + # Emit the OpenRouter base URL at source so every call path passes an + # explicit api_base and cannot inherit a process-global endpoint. assert cfg["api_base"] == "https://openrouter.ai/api/v1" diff --git a/surfsense_backend/tests/unit/services/test_pricing_registration.py b/surfsense_backend/tests/unit/services/test_pricing_registration.py index e97250ff2..c9adc6aac 100644 --- a/surfsense_backend/tests/unit/services/test_pricing_registration.py +++ b/surfsense_backend/tests/unit/services/test_pricing_registration.py @@ -186,7 +186,7 @@ def test_openrouter_models_register_under_aliases(monkeypatch): [ { "id": 1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "anthropic/claude-3-5-sonnet", } ], @@ -228,7 +228,7 @@ def test_yaml_override_registers_under_alias_set(monkeypatch): [ { "id": 1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5.4", "litellm_params": { "base_model": "gpt-5.4", @@ -243,7 +243,6 @@ def test_yaml_override_registers_under_alias_set(monkeypatch): keys = spy.all_keys assert "gpt-5.4" in keys - assert "azure_openai/gpt-5.4" in keys assert "azure/gpt-5.4" in keys payload = spy.calls[0] @@ -271,7 +270,7 @@ def test_no_override_means_no_registration(monkeypatch): [ { "id": 1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-4o", "litellm_params": {"base_model": "gpt-4o"}, } @@ -302,7 +301,7 @@ def test_openrouter_skipped_when_pricing_missing(monkeypatch): [ { "id": 1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "anthropic/claude-3-5-sonnet", } ], @@ -349,12 +348,12 @@ def test_register_continues_after_individual_failure(monkeypatch, caplog): [ { "id": 1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "anthropic/claude-3-5-sonnet", }, { "id": 2, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "custom-deployment", "litellm_params": { "base_model": "custom-deployment", @@ -396,7 +395,7 @@ def test_vision_configs_registered_with_chat_shape(monkeypatch): [ { "id": -1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "openai/gpt-4o", "billing_tier": "premium", "input_cost_per_token": 5e-6, @@ -433,7 +432,7 @@ def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch): [ { "id": -1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemini-2.5-flash", "billing_tier": "premium", "input_cost_per_token": 1e-6, diff --git a/surfsense_backend/tests/unit/services/test_provider_api_base.py b/surfsense_backend/tests/unit/services/test_provider_api_base.py deleted file mode 100644 index 12cd0a3d5..000000000 --- a/surfsense_backend/tests/unit/services/test_provider_api_base.py +++ /dev/null @@ -1,107 +0,0 @@ -"""Unit tests for the shared ``api_base`` resolver. - -The cascade exists so vision and image-gen call sites can't silently -inherit ``litellm.api_base`` (commonly set by ``AZURE_OPENAI_ENDPOINT``) -when an OpenRouter / Groq / etc. config ships an empty string. See -``provider_api_base`` module docstring for the original repro -(OpenRouter image-gen 404-ing against an Azure endpoint). -""" - -from __future__ import annotations - -import pytest - -from app.services.provider_api_base import ( - PROVIDER_DEFAULT_API_BASE, - PROVIDER_KEY_DEFAULT_API_BASE, - resolve_api_base, -) - -pytestmark = pytest.mark.unit - - -def test_config_value_wins_over_defaults(): - """A non-empty config value is always returned verbatim, even when the - provider has a default — the operator gets the last word.""" - result = resolve_api_base( - provider="OPENROUTER", - provider_prefix="openrouter", - config_api_base="https://my-openrouter-mirror.example.com/v1", - ) - assert result == "https://my-openrouter-mirror.example.com/v1" - - -def test_provider_key_default_when_config_missing(): - """``DEEPSEEK`` shares the ``openai`` LiteLLM prefix but has its own - base URL — the provider-key map must take precedence over the prefix - map so DeepSeek requests don't go to OpenAI.""" - result = resolve_api_base( - provider="DEEPSEEK", - provider_prefix="openai", - config_api_base=None, - ) - assert result == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"] - - -def test_provider_prefix_default_when_no_key_default(): - result = resolve_api_base( - provider="OPENROUTER", - provider_prefix="openrouter", - config_api_base=None, - ) - assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] - - -def test_unknown_provider_returns_none(): - """When neither map matches we return ``None`` so the caller can let - LiteLLM apply its own provider-integration default (Azure deployment - URL, custom-provider URL, etc.).""" - result = resolve_api_base( - provider="SOMETHING_NEW", - provider_prefix="something_new", - config_api_base=None, - ) - assert result is None - - -def test_empty_string_config_treated_as_missing(): - """The original bug: OpenRouter dynamic configs ship ``api_base=""`` - and downstream call sites use ``if cfg.get("api_base"):`` — empty - strings are falsy in Python but the cascade has to step in anyway.""" - result = resolve_api_base( - provider="OPENROUTER", - provider_prefix="openrouter", - config_api_base="", - ) - assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] - - -def test_whitespace_only_config_treated_as_missing(): - """A config value of ``" "`` is a configuration mistake — treat it - as missing instead of forwarding whitespace to LiteLLM (which would - almost certainly 404).""" - result = resolve_api_base( - provider="OPENROUTER", - provider_prefix="openrouter", - config_api_base=" ", - ) - assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] - - -def test_provider_case_insensitive(): - """Some call sites pass the provider lowercase (DB enum value), others - uppercase (YAML key). Both must resolve.""" - upper = resolve_api_base( - provider="DEEPSEEK", provider_prefix="openai", config_api_base=None - ) - lower = resolve_api_base( - provider="deepseek", provider_prefix="openai", config_api_base=None - ) - assert upper == lower == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"] - - -def test_all_inputs_none_returns_none(): - assert ( - resolve_api_base(provider=None, provider_prefix=None, config_api_base=None) - is None - ) diff --git a/surfsense_backend/tests/unit/services/test_provider_capabilities.py b/surfsense_backend/tests/unit/services/test_provider_capabilities.py index aac88977f..c327c2c87 100644 --- a/surfsense_backend/tests/unit/services/test_provider_capabilities.py +++ b/surfsense_backend/tests/unit/services/test_provider_capabilities.py @@ -32,7 +32,7 @@ pytestmark = pytest.mark.unit def test_or_modalities_with_image_returns_true(): assert ( derive_supports_image_input( - provider="OPENROUTER", + litellm_provider="openrouter", model_name="openai/gpt-4o", openrouter_input_modalities=["text", "image"], ) @@ -43,7 +43,7 @@ def test_or_modalities_with_image_returns_true(): def test_or_modalities_text_only_returns_false(): assert ( derive_supports_image_input( - provider="OPENROUTER", + litellm_provider="openrouter", model_name="deepseek/deepseek-v3.2-exp", openrouter_input_modalities=["text"], ) @@ -57,7 +57,7 @@ def test_or_modalities_empty_list_returns_false(): to LiteLLM.""" assert ( derive_supports_image_input( - provider="OPENROUTER", + litellm_provider="openrouter", model_name="weird/empty-modalities", openrouter_input_modalities=[], ) @@ -70,7 +70,7 @@ def test_or_modalities_none_falls_through_to_litellm(): to LiteLLM. Using ``openai/gpt-4o`` which is in LiteLLM's map.""" assert ( derive_supports_image_input( - provider="OPENAI", + litellm_provider="openai", model_name="gpt-4o", openrouter_input_modalities=None, ) @@ -86,7 +86,7 @@ def test_or_modalities_none_falls_through_to_litellm(): def test_litellm_known_vision_model_returns_true(): assert ( derive_supports_image_input( - provider="OPENAI", + litellm_provider="openai", model_name="gpt-4o", ) is True @@ -100,7 +100,7 @@ def test_litellm_base_model_wins_over_model_name(): doesn't know) would shadow the real capability.""" assert ( derive_supports_image_input( - provider="AZURE_OPENAI", + litellm_provider="azure", model_name="my-azure-deployment-id", base_model="gpt-4o", ) @@ -112,7 +112,7 @@ def test_litellm_unknown_model_default_allows(): """Default-allow on unknown — the safety net is the actual block.""" assert ( derive_supports_image_input( - provider="CUSTOM", + litellm_provider="custom", model_name="brand-new-model-x9-unmapped", custom_provider="brand_new_proxy", ) @@ -128,7 +128,7 @@ def test_litellm_known_text_only_returns_false(): # Sanity: confirm the helper's negative path. We use a small model # known not to support vision per the map. result = derive_supports_image_input( - provider="DEEPSEEK", + litellm_provider="openai", model_name="deepseek-chat", ) # We accept either False (LiteLLM said explicit no) or True @@ -147,7 +147,7 @@ def test_litellm_known_text_only_returns_false(): def test_is_known_text_only_returns_false_for_vision_model(): assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="gpt-4o", ) is False @@ -160,7 +160,7 @@ def test_is_known_text_only_returns_false_for_unknown_model(): fixing.""" assert ( is_known_text_only_chat_model( - provider="CUSTOM", + litellm_provider="custom", model_name="brand-new-model-x9-unmapped", custom_provider="brand_new_proxy", ) @@ -181,7 +181,7 @@ def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch): assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="gpt-4o", ) is False @@ -201,7 +201,7 @@ def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch): assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="any-model", ) is True @@ -218,7 +218,7 @@ def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch): assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="any-model", ) is False @@ -237,7 +237,7 @@ def test_is_known_text_only_returns_false_on_missing_key(monkeypatch): assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="any-model", ) is False diff --git a/surfsense_backend/tests/unit/services/test_quality_score.py b/surfsense_backend/tests/unit/services/test_quality_score.py index 6fbc8fd62..369c8b8f3 100644 --- a/surfsense_backend/tests/unit/services/test_quality_score.py +++ b/surfsense_backend/tests/unit/services/test_quality_score.py @@ -228,7 +228,7 @@ def test_static_score_or_recent_release_beats_year_old_same_provider(): def test_static_score_yaml_includes_operator_bonus(): cfg = { - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "litellm_params": {"base_model": "azure/gpt-5"}, } @@ -238,7 +238,7 @@ def test_static_score_yaml_includes_operator_bonus(): def test_static_score_yaml_unknown_provider_still_carries_bonus(): cfg = { - "provider": "SOME_NEW_PROVIDER", + "litellm_provider": "some_new_provider", "model_name": "weird-model", } score = static_score_yaml(cfg) @@ -247,7 +247,7 @@ def test_static_score_yaml_unknown_provider_still_carries_bonus(): def test_static_score_yaml_clamped_0_to_100(): cfg = { - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "litellm_params": {"base_model": "azure/gpt-5"}, } From c28c4f5785bb0bbdfe8997ef157f8fedc3276d72 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 18:22:23 +0530 Subject: [PATCH 19/59] feat(chat): route models by provider capabilities --- .../app/agents/chat/runtime/llm_config.py | 53 ++--- .../app/routes/anonymous_chat_routes.py | 4 +- .../app/routes/vision_llm_routes.py | 2 +- .../app/services/auto_model_pin_service.py | 206 ++++++++++++++---- .../app/services/llm_router_service.py | 10 +- surfsense_backend/app/services/llm_service.py | 75 +++---- .../app/services/model_list_service.py | 10 +- .../app/services/model_resolver.py | 110 +++------- .../app/services/vision_llm_router_service.py | 14 +- .../app/services/vision_model_list_service.py | 8 +- .../flows/new_chat/llm_capability.py | 2 +- .../streaming/flows/new_chat/title_gen.py | 17 +- .../chat/streaming/flows/shared/llm_bundle.py | 137 ++++++++++-- .../services/test_auto_model_pin_service.py | 68 +++--- .../services/test_auto_pin_image_aware.py | 6 +- .../services/test_llm_router_pool_filter.py | 8 +- .../services/test_or_health_enrichment.py | 6 +- .../test_stream_new_chat_image_safety_net.py | 12 +- 18 files changed, 429 insertions(+), 319 deletions(-) diff --git a/surfsense_backend/app/agents/chat/runtime/llm_config.py b/surfsense_backend/app/agents/chat/runtime/llm_config.py index 03d7f548e..b9344e001 100644 --- a/surfsense_backend/app/agents/chat/runtime/llm_config.py +++ b/surfsense_backend/app/agents/chat/runtime/llm_config.py @@ -2,9 +2,9 @@ LLM configuration utilities for SurfSense agents. This module provides functions for loading LLM configurations from: -1. Auto mode (ID 0) - Uses LiteLLM Router for load balancing +1. Auto mode (ID 0) - Resolved by callers to a concrete model-connection model 2. YAML files (global configs with negative IDs) -3. Database NewLLMConfig table (user-created configs with positive IDs) +3. Database model-connections table (user-created configs with positive IDs) It also provides utilities for creating ChatLiteLLM instances and managing prompt configurations. @@ -33,9 +33,7 @@ from app.agents.chat.runtime.prompt_caching import ( from app.services.llm_router_service import ( AUTO_MODE_ID, ChatLiteLLMRouter, - LLMRouterService, _sanitize_content, - get_auto_mode_llm, is_auto_mode, ) @@ -92,14 +90,6 @@ class SanitizedChatLiteLLM(ChatLiteLLM): yield chunk -# Re-exported under the historical name ``PROVIDER_MAP``. Source of truth lives -# in provider_capabilities so the YAML loader can resolve prefixes during -# app.config init without importing the agent/tools tree. -from app.services.provider_capabilities import ( # noqa: E402 - _PROVIDER_PREFIX_MAP as PROVIDER_MAP, -) - - def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None: """Attach a ``profile`` dict to ChatLiteLLM with model context metadata.""" try: @@ -122,7 +112,8 @@ class AgentConfig: Complete configuration for the SurfSense agent. This combines LLM settings with prompt configuration from NewLLMConfig. - Supports Auto mode (ID 0) which uses LiteLLM Router for load balancing. + Supports Auto mode metadata (ID 0). Runtime callers must resolve Auto to + a concrete global or BYOK model before constructing ChatLiteLLM. """ # LLM Model Settings @@ -219,7 +210,7 @@ class AgentConfig: # BYOK rows have no curated flag; ask LiteLLM (default-allow on # unknown). The streaming safety net still blocks explicit text-only. supports_image_input=derive_supports_image_input( - provider=provider_value, + litellm_provider=provider_value.lower(), model_name=config.model_name, base_model=base_model, custom_provider=config.custom_provider, @@ -238,7 +229,7 @@ class AgentConfig: system_instructions = yaml_config.get("system_instructions", "") - provider = yaml_config.get("provider", "").upper() + provider = yaml_config.get("litellm_provider", "") model_name = yaml_config.get("model_name", "") custom_provider = yaml_config.get("custom_provider") litellm_params = yaml_config.get("litellm_params") or {} @@ -254,7 +245,7 @@ class AgentConfig: supports_image_input = bool(yaml_config.get("supports_image_input")) else: supports_image_input = derive_supports_image_input( - provider=provider, + litellm_provider=provider, model_name=model_name, base_model=base_model, custom_provider=custom_provider, @@ -383,9 +374,6 @@ async def load_agent_config( ) -> "AgentConfig | None": """Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB.""" if is_auto_mode(config_id): - if not LLMRouterService.is_initialized(): - print("Error: Auto mode requested but LLM Router not initialized") - return None return AgentConfig.from_auto_mode() if config_id < 0: @@ -408,9 +396,8 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: if llm_config.get("custom_provider"): model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}" else: - provider = llm_config.get("provider", "").upper() - provider_prefix = PROVIDER_MAP.get(provider, provider.lower()) - model_string = f"{provider_prefix}/{llm_config['model_name']}" + litellm_provider = llm_config.get("litellm_provider", "openai") + model_string = f"{litellm_provider}/{llm_config['model_name']}" litellm_kwargs = { "model": model_string, @@ -433,29 +420,15 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: def create_chat_litellm_from_agent_config( agent_config: AgentConfig, ) -> ChatLiteLLM | ChatLiteLLMRouter | None: - """Create a ChatLiteLLM (or, for Auto mode, a load-balancing router) from config.""" + """Create a ChatLiteLLM from an already resolved concrete model config.""" if agent_config.is_auto_mode: - if not LLMRouterService.is_initialized(): - print("Error: Auto mode requested but LLM Router not initialized") - return None - try: - router_llm = get_auto_mode_llm() - if router_llm is not None: - # Universal injection points only: auto-mode fans out across - # providers, so provider-specific kwargs have no known target. - apply_litellm_prompt_caching(router_llm, agent_config=agent_config) - return router_llm - except Exception as e: - print(f"Error creating ChatLiteLLMRouter: {e}") - return None + print("Error: Auto mode must be resolved to a concrete model before LLM creation") + return None if agent_config.custom_provider: model_string = f"{agent_config.custom_provider}/{agent_config.model_name}" else: - provider_prefix = PROVIDER_MAP.get( - agent_config.provider, agent_config.provider.lower() - ) - model_string = f"{provider_prefix}/{agent_config.model_name}" + model_string = f"{agent_config.provider}/{agent_config.model_name}" litellm_kwargs = { "model": model_string, diff --git a/surfsense_backend/app/routes/anonymous_chat_routes.py b/surfsense_backend/app/routes/anonymous_chat_routes.py index ad3277375..aba1a3a12 100644 --- a/surfsense_backend/app/routes/anonymous_chat_routes.py +++ b/surfsense_backend/app/routes/anonymous_chat_routes.py @@ -132,7 +132,7 @@ async def list_anonymous_models(): id=cfg.get("id", 0), name=cfg.get("name", ""), description=cfg.get("description"), - provider=cfg.get("provider", ""), + provider=cfg.get("litellm_provider", ""), model_name=cfg.get("model_name", ""), billing_tier=cfg.get("billing_tier", "free"), is_premium=cfg.get("billing_tier", "free") == "premium", @@ -161,7 +161,7 @@ async def get_anonymous_model(slug: str): id=cfg.get("id", 0), name=cfg.get("name", ""), description=cfg.get("description"), - provider=cfg.get("provider", ""), + provider=cfg.get("litellm_provider", ""), model_name=cfg.get("model_name", ""), billing_tier=cfg.get("billing_tier", "free"), is_premium=cfg.get("billing_tier", "free") == "premium", diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py index e4f08f604..df218daac 100644 --- a/surfsense_backend/app/routes/vision_llm_routes.py +++ b/surfsense_backend/app/routes/vision_llm_routes.py @@ -96,7 +96,7 @@ async def get_global_vision_llm_configs( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("provider"), + "provider": cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index 9bbca8669..ee8c4b8dc 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -23,9 +23,10 @@ from uuid import UUID from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from app.config import config -from app.db import NewChatThread +from app.db import Connection, Model, NewChatThread from app.services.quality_score import _QUALITY_TOP_K from app.services.token_quota_service import TokenQuotaService @@ -61,11 +62,20 @@ def _is_usable_global_config(cfg: dict) -> bool: return bool( cfg.get("id") is not None and cfg.get("model_name") - and cfg.get("provider") + and cfg.get("litellm_provider") and cfg.get("api_key") ) +def _has_capability(model: dict | Model, capability: str) -> bool: + caps = ( + model.get("capabilities", {}) + if isinstance(model, dict) + else model.capabilities or {} + ) + return bool(caps.get(capability)) + + def _prune_runtime_cooldowns(now_ts: float | None = None) -> None: now = time.time() if now_ts is None else now_ts stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now] @@ -186,15 +196,19 @@ def _cfg_supports_image_input(cfg: dict) -> bool: else None ) return derive_supports_image_input( - provider=cfg.get("provider"), + litellm_provider=cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), ) -def _global_candidates(*, requires_image_input: bool = False) -> list[dict]: - """Return Auto-eligible global cfgs. +def _global_candidates( + *, + capability: str = "chat", + requires_image_input: bool = False, +) -> list[dict]: + """Return Auto-eligible global virtual models. Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers @@ -205,17 +219,135 @@ def _global_candidates(*, requires_image_input: bool = False) -> list[dict]: filters out configs whose ``supports_image_input`` resolves to False so a text-only deployment can't be pinned for an image request. """ - candidates = [ - cfg + connection_by_id = { + int(conn.get("id")): conn + for conn in config.GLOBAL_CONNECTIONS + if conn.get("id") is not None + } + config_by_model_name = { + cfg.get("model_name"): cfg for cfg in config.GLOBAL_LLM_CONFIGS if _is_usable_global_config(cfg) - and not cfg.get("health_gated") - and not _is_runtime_cooled_down(int(cfg.get("id", 0))) - and (not requires_image_input or _cfg_supports_image_input(cfg)) - ] + } + candidates: list[dict] = [] + for model in config.GLOBAL_MODELS: + model_id = int(model.get("id", 0)) + if model_id >= 0 or _is_runtime_cooled_down(model_id): + continue + if not _has_capability(model, capability): + continue + cfg = config_by_model_name.get(model.get("model_id")) or {} + if cfg.get("health_gated"): + continue + if requires_image_input and not _has_capability(model, "vision"): + continue + if requires_image_input and cfg and not _cfg_supports_image_input(cfg): + continue + connection = connection_by_id.get(int(model.get("connection_id", 0))) + if not connection: + continue + catalog = model.get("catalog") or {} + candidates.append( + { + "id": model_id, + "model_id": model.get("model_id"), + "source": "global", + "connection": connection, + "capabilities": model.get("capabilities") or {}, + "billing_tier": model.get("billing_tier", "free"), + "litellm_provider": connection.get("litellm_provider"), + "model_name": model.get("model_id"), + "auto_pin_tier": catalog.get("auto_pin_tier") + or cfg.get("auto_pin_tier") + or "A", + "quality_score": catalog.get("quality_score") + or cfg.get("quality_score") + or cfg.get("quality_score_static") + or 50, + } + ) return sorted(candidates, key=lambda c: int(c.get("id", 0))) +async def _db_candidates( + session: AsyncSession, + *, + search_space_id: int, + user_id: str | UUID | None, + capability: str, + requires_image_input: bool = False, +) -> list[dict]: + parsed_user_id = _to_uuid(user_id) + stmt = ( + select(Model) + .options(selectinload(Model.connection)) + .join(Connection, Model.connection_id == Connection.id) + .where(Model.enabled.is_(True), Connection.enabled.is_(True)) + ) + result = await session.execute(stmt) + candidates: list[dict] = [] + for model in result.scalars().all(): + conn = model.connection + if not conn: + continue + if conn.search_space_id is not None and conn.search_space_id != search_space_id: + continue + if conn.user_id is not None and parsed_user_id is not None and conn.user_id != parsed_user_id: + continue + if conn.user_id is not None and parsed_user_id is None: + continue + if not _has_capability(model, capability): + continue + if requires_image_input and not _has_capability(model, "vision"): + continue + model_id = int(model.id) + if _is_runtime_cooled_down(model_id): + continue + catalog = model.catalog or {} + candidates.append( + { + "id": model_id, + "model_id": model.model_id, + "source": "db", + "connection": conn, + "capabilities": model.capabilities or {}, + "billing_tier": "byok", + "litellm_provider": conn.litellm_provider, + "model_name": model.model_id, + "auto_pin_tier": catalog.get("auto_pin_tier") or "BYOK", + "quality_score": catalog.get("quality_score") or 75, + } + ) + return sorted(candidates, key=lambda c: int(c.get("id", 0))) + + +async def auto_model_candidates( + session: AsyncSession, + *, + search_space_id: int, + user_id: str | UUID | None, + capability: str, + requires_image_input: bool = False, + exclude_model_ids: set[int] | None = None, +) -> list[dict]: + excluded_ids = {int(mid) for mid in (exclude_model_ids or set())} + db_candidates = await _db_candidates( + session, + search_space_id=search_space_id, + user_id=user_id, + capability=capability, + requires_image_input=requires_image_input, + ) + candidates = [ + *_global_candidates( + capability=capability, + requires_image_input=requires_image_input, + ), + *db_candidates, + ] + return [c for c in candidates if int(c.get("id", 0)) not in excluded_ids] + + def _tier_of(cfg: dict) -> str: return str(cfg.get("billing_tier", "free")).lower() @@ -223,8 +355,9 @@ def _tier_of(cfg: dict) -> str: def _is_preferred_premium_auto_config(cfg: dict) -> bool: """Return True for the operator-preferred premium Auto model.""" return ( - _tier_of(cfg) == "premium" - and str(cfg.get("provider", "")).upper() == "AZURE_OPENAI" + cfg.get("source") == "global" + and _tier_of(cfg) == "premium" + and str(cfg.get("litellm_provider", "")).lower() == "azure" and str(cfg.get("model_name", "")).lower() == "gpt-5.4" ) @@ -251,6 +384,11 @@ def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]: return top_k[idx], len(top_k) +def choose_auto_model_candidate(candidates: list[dict], seed_id: int) -> dict: + selected, _ = _select_pin(candidates, seed_id) + return selected + + def _to_uuid(user_id: str | UUID | None) -> UUID | None: if user_id is None: return None @@ -326,20 +464,23 @@ async def resolve_or_get_pinned_llm_config_id( ) excluded_ids = {int(cid) for cid in (exclude_config_ids or set())} - candidates = [ - c - for c in _global_candidates(requires_image_input=requires_image_input) - if int(c.get("id", 0)) not in excluded_ids - ] + candidates = await auto_model_candidates( + session, + search_space_id=search_space_id, + user_id=user_id, + capability="chat", + requires_image_input=requires_image_input, + exclude_model_ids=excluded_ids, + ) if not candidates: if requires_image_input: # Distinguish the "no vision-capable cfg" case from generic # "no usable cfg" so the streaming task can map this to the # MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error. raise ValueError( - "No vision-capable global LLM configs are available for Auto mode" + "No vision-capable LLM models are available for Auto mode" ) - raise ValueError("No usable global LLM configs are available for Auto mode") + raise ValueError("No usable LLM models are available for Auto mode") candidate_by_id = {int(c["id"]): c for c in candidates} # Reuse an existing valid pin without re-checking current quota (no silent @@ -379,24 +520,13 @@ async def resolve_or_get_pinned_llm_config_id( # log that explicitly so operators can correlate the re-pin with # the user's image attachment instead of suspecting a cooldown. if requires_image_input: - try: - pinned_global = next( - c - for c in config.GLOBAL_LLM_CONFIGS - if int(c.get("id", 0)) == int(pinned_id) - ) - except StopIteration: - pinned_global = None - if pinned_global is not None and not _cfg_supports_image_input( - pinned_global - ): - logger.info( - "auto_pin_repinned_for_image thread_id=%s search_space_id=%s " - "previous_config_id=%s", - thread_id, - search_space_id, - pinned_id, - ) + logger.info( + "auto_pin_repinned_for_image thread_id=%s search_space_id=%s " + "previous_config_id=%s", + thread_id, + search_space_id, + pinned_id, + ) logger.info( "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s", thread_id, diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 69feb30eb..a151a0d6e 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -30,11 +30,7 @@ from litellm.exceptions import ( ) from pydantic import Field -from app.services.model_resolver import ( - NATIVE_PROVIDER_PREFIX, - native_connection_from_config, - to_litellm, -) +from app.services.model_resolver import native_connection_from_config, to_litellm from app.utils.perf import get_perf_logger litellm.json_logs = False @@ -101,10 +97,6 @@ def _sanitize_content(content: Any) -> Any: # Special ID for Auto mode - uses router for load balancing AUTO_MODE_ID = 0 -# Historical export kept for callers that still import PROVIDER_MAP. -PROVIDER_MAP = NATIVE_PROVIDER_PREFIX - - class LLMRouterService: """ Singleton service for managing LiteLLM Router. diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 75451d01f..86a9c8556 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -10,13 +10,11 @@ from sqlalchemy.orm import selectinload from app.config import config from app.db import Model, SearchSpace -from app.services.llm_router_service import ( - AUTO_MODE_ID, - ChatLiteLLMRouter, - LLMRouterService, - get_auto_mode_llm, - is_auto_mode, +from app.services.auto_model_pin_service import ( + auto_model_candidates, + choose_auto_model_candidate, ) +from app.services.llm_router_service import AUTO_MODE_ID, ChatLiteLLMRouter, is_auto_mode from app.services.model_resolver import native_connection_from_config, to_litellm from app.services.token_tracking_service import token_tracker @@ -78,7 +76,7 @@ def _legacy_config_connection( api_version: str | None = None, ) -> tuple[str, dict]: cfg = { - "provider": provider, + "litellm_provider": provider.lower(), "model_name": model_name, "api_key": api_key, "api_base": api_base, @@ -325,23 +323,21 @@ async def get_search_space_llm_instance( logger.error(f"No {role} LLM configured for search space {search_space_id}") return None - # Check for Auto mode (ID 0) - use router for load balancing + # Auto mode resolves to one concrete global or BYOK model from the + # unified model-connections catalog. if is_auto_mode(llm_config_id): - if not LLMRouterService.is_initialized(): - logger.error( - "Auto mode requested but LLM Router not initialized. " - "Ensure global_llm_config.yaml exists with valid configs." - ) - return None - - try: - logger.debug( - f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}" - ) - return get_auto_mode_llm(streaming=not disable_streaming) - except Exception as e: - logger.error(f"Failed to create ChatLiteLLMRouter: {e}") + candidates = await auto_model_candidates( + session, + search_space_id=search_space_id, + user_id=search_space.user_id, + capability="chat", + ) + if not candidates: + logger.error("No chat-capable models available for Auto mode") return None + llm_config_id = int( + choose_auto_model_candidate(candidates, search_space_id)["id"] + ) # Check if this is a global virtual model (negative ID) if llm_config_id < 0: @@ -414,7 +410,7 @@ async def get_vision_llm( """Get the search space's vision LLM instance for screenshot analysis. Resolves from the new connection/model role bindings: - - Auto mode (ID 0): VisionLLMRouterService + - Auto mode (ID 0): unified global/BYOK model candidate selection - Global (negative ID): virtual GLOBAL models from YAML - DB (positive ID): Model + Connection tables @@ -424,10 +420,7 @@ async def get_vision_llm( unwrapped — they don't consume premium credit (issue M). """ from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM - from app.services.vision_llm_router_service import ( - VisionLLMRouterService, - is_vision_auto_mode, - ) + from app.services.vision_llm_router_service import is_vision_auto_mode try: result = await session.execute( @@ -476,26 +469,16 @@ async def get_vision_llm( return None if is_vision_auto_mode(config_id): - if not VisionLLMRouterService.is_initialized(): - logger.error( - "Vision Auto mode requested but Vision LLM Router not initialized" - ) - return None - try: - # Auto mode is currently treated as free at the wrapper - # level — the underlying router can dispatch to either - # premium or free YAML configs but routing decisions are - # opaque. If/when we want to bill Auto-routed vision - # calls we'd need to thread the resolved deployment's - # billing_tier back from the router. For now we keep - # parity with chat Auto, which also doesn't pre-classify. - return ChatLiteLLMRouter( - router=VisionLLMRouterService.get_router(), - streaming=True, - ) - except Exception as e: - logger.error(f"Failed to create vision ChatLiteLLMRouter: {e}") + candidates = await auto_model_candidates( + session, + search_space_id=search_space_id, + user_id=owner_user_id, + capability="vision", + ) + if not candidates: + logger.error("No vision-capable models available for Auto mode") return None + config_id = int(choose_auto_model_candidate(candidates, search_space_id)["id"]) if config_id < 0: global_model = get_global_model(config_id) diff --git a/surfsense_backend/app/services/model_list_service.py b/surfsense_backend/app/services/model_list_service.py index 33837a8a0..1ef0b0c90 100644 --- a/surfsense_backend/app/services/model_list_service.py +++ b/surfsense_backend/app/services/model_list_service.py @@ -154,19 +154,19 @@ def _process_models(raw_models: list[dict]) -> list[dict]: } ) - # 2) Emit for the native provider when we have a mapping - native_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug) - if native_provider: + # 2) Emit for the direct provider when we have a mapping + direct_provider = OPENROUTER_SLUG_TO_PROVIDER.get(provider_slug) + if direct_provider: # Google's Gemini API only serves gemini-* models. # Open-source models like gemma-* are NOT available through it. - if native_provider == "GOOGLE" and not model_name.startswith("gemini-"): + if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"): continue processed.append( { "value": model_name, "label": name, - "provider": native_provider, + "provider": direct_provider, "context_window": context_window, } ) diff --git a/surfsense_backend/app/services/model_resolver.py b/surfsense_backend/app/services/model_resolver.py index ec485a5ae..ffa77a9a2 100644 --- a/surfsense_backend/app/services/model_resolver.py +++ b/surfsense_backend/app/services/model_resolver.py @@ -9,53 +9,12 @@ from __future__ import annotations from collections.abc import Mapping from typing import TYPE_CHECKING, Any -from app.services.provider_api_base import resolve_api_base - if TYPE_CHECKING: from app.db import Connection PROTOCOL_OLLAMA = "OLLAMA" PROTOCOL_OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" -PROTOCOL_NATIVE = "NATIVE" - -NATIVE_PROVIDER_PREFIX: dict[str, str] = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "AZURE": "azure", - "OPENROUTER": "openrouter", - "COMETAPI": "cometapi", - "XAI": "xai", - "BEDROCK": "bedrock", - "AWS_BEDROCK": "bedrock", - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", - "GITHUB_MODELS": "github", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - "HUGGINGFACE": "huggingface", - "MINIMAX": "openai", - "RECRAFT": "recraft", - "XINFERENCE": "xinference", - "NSCALE": "nscale", - "CUSTOM": "custom", -} +PROTOCOL_ANTHROPIC = "ANTHROPIC" def ensure_v1(base_url: str | None) -> str | None: @@ -77,6 +36,23 @@ def _protocol_value(protocol: Any) -> str: return getattr(protocol, "value", str(protocol)) +def default_litellm_provider(protocol: Any) -> str: + protocol_value = _protocol_value(protocol) + defaults = { + PROTOCOL_OLLAMA: "ollama_chat", + PROTOCOL_ANTHROPIC: "anthropic", + PROTOCOL_OPENAI_COMPATIBLE: "openai", + } + return defaults.get(protocol_value, "openai") + + +def _execution_api_base(protocol: str, base_url: str | None) -> str | None: + del protocol + if not base_url: + return None + return base_url.rstrip("/") + + def to_litellm( conn: Connection | Mapping[str, Any], model_id: str, @@ -85,38 +61,19 @@ def to_litellm( protocol = _protocol_value(_conn_value(conn, "protocol")) base_url = _conn_value(conn, "base_url") api_key = _conn_value(conn, "api_key") - native_provider = _conn_value(conn, "native_provider") + litellm_provider = ( + _conn_value(conn, "litellm_provider") or default_litellm_provider(protocol) + ) extra = _conn_value(conn, "extra") or {} kwargs: dict[str, Any] = {} if api_key: kwargs["api_key"] = api_key - if protocol == PROTOCOL_OLLAMA: - model_string = f"ollama_chat/{model_id}" - if base_url: - kwargs["api_base"] = base_url.rstrip("/") - elif protocol == PROTOCOL_OPENAI_COMPATIBLE: - model_string = f"openai/{model_id}" - api_base = ensure_v1(base_url) - if api_base: - kwargs["api_base"] = api_base - else: - provider_key = (native_provider or "").upper() - prefix = NATIVE_PROVIDER_PREFIX.get(provider_key, provider_key.lower()) - if prefix == "custom": - custom_provider = extra.get("custom_provider") or native_provider - model_string = f"{custom_provider}/{model_id}" if custom_provider else model_id - else: - model_string = f"{prefix}/{model_id}" - - api_base = resolve_api_base( - provider=provider_key, - provider_prefix=prefix, - config_api_base=base_url, - ) - if api_base: - kwargs["api_base"] = api_base + model_string = f"{litellm_provider}/{model_id}" if litellm_provider else model_id + api_base = _execution_api_base(protocol, base_url) + if api_base: + kwargs["api_base"] = api_base if api_version := extra.get("api_version"): kwargs["api_version"] = api_version @@ -126,18 +83,21 @@ def to_litellm( def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: - """Build an in-memory NATIVE connection mapping from a legacy/global config.""" - provider = str(config.get("provider") or config.get("custom_provider") or "CUSTOM") + """Build an in-memory connection mapping from a global config.""" + protocol = str(config.get("protocol") or PROTOCOL_OPENAI_COMPATIBLE) + litellm_provider = str( + config.get("litellm_provider") + or config.get("custom_provider") + or default_litellm_provider(protocol) + ) extra: dict[str, Any] = { "litellm_params": config.get("litellm_params") or {}, } if config.get("api_version"): extra["api_version"] = config.get("api_version") - if config.get("custom_provider"): - extra["custom_provider"] = config.get("custom_provider") return { - "protocol": PROTOCOL_NATIVE, - "native_provider": provider, + "protocol": protocol, + "litellm_provider": litellm_provider, "base_url": config.get("api_base") or None, "api_key": config.get("api_key") or None, "extra": extra, @@ -145,7 +105,7 @@ def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: __all__ = [ - "NATIVE_PROVIDER_PREFIX", + "default_litellm_provider", "ensure_v1", "native_connection_from_config", "to_litellm", diff --git a/surfsense_backend/app/services/vision_llm_router_service.py b/surfsense_backend/app/services/vision_llm_router_service.py index 0c7182ecf..0ff716324 100644 --- a/surfsense_backend/app/services/vision_llm_router_service.py +++ b/surfsense_backend/app/services/vision_llm_router_service.py @@ -3,19 +3,12 @@ from typing import Any from litellm import Router -from app.services.model_resolver import ( - NATIVE_PROVIDER_PREFIX, - native_connection_from_config, - to_litellm, -) +from app.services.model_resolver import native_connection_from_config, to_litellm logger = logging.getLogger(__name__) VISION_AUTO_MODE_ID = 0 -VISION_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX - - class VisionLLMRouterService: _instance = None _router: Router | None = None @@ -141,12 +134,11 @@ def is_vision_auto_mode(config_id: int | None) -> bool: def build_vision_model_string( - provider: str, model_name: str, custom_provider: str | None + litellm_provider: str, model_name: str, custom_provider: str | None ) -> str: if custom_provider: return f"{custom_provider}/{model_name}" - prefix = VISION_PROVIDER_MAP.get(provider.upper(), provider.lower()) - return f"{prefix}/{model_name}" + return f"{litellm_provider}/{model_name}" def get_global_vision_llm_config(config_id: int) -> dict | None: diff --git a/surfsense_backend/app/services/vision_model_list_service.py b/surfsense_backend/app/services/vision_model_list_service.py index fc459910b..6eae8c455 100644 --- a/surfsense_backend/app/services/vision_model_list_service.py +++ b/surfsense_backend/app/services/vision_model_list_service.py @@ -97,16 +97,16 @@ def _process_vision_models(raw_models: list[dict]) -> list[dict]: } ) - native_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug) - if native_provider: - if native_provider == "GOOGLE" and not model_name.startswith("gemini-"): + direct_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug) + if direct_provider: + if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"): continue processed.append( { "value": model_name, "label": name, - "provider": native_provider, + "provider": direct_provider, "context_window": context_window, } ) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py index 69b9f4ab8..f6fcf75d7 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py @@ -40,7 +40,7 @@ def check_image_input_capability( else None ) if not is_known_text_only_chat_model( - provider=agent_config.provider, + litellm_provider=agent_config.provider, model_name=agent_config.model_name, base_model=agent_base_model, custom_provider=agent_config.custom_provider, diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py index fe3d210bb..d5e8c3729 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/title_gen.py @@ -80,7 +80,6 @@ async def _generate_title( from litellm import acompletion from app.services.llm_router_service import LLMRouterService - from app.services.provider_api_base import resolve_api_base from app.services.token_tracking_service import _turn_accumulator # Excludes this turn's own assistant row (pre-written by @@ -125,26 +124,12 @@ async def _generate_title( router = LLMRouterService.get_router() response = await router.acompletion(model="auto", messages=messages) else: - # Apply the same ``api_base`` cascade chat / vision / image-gen - # call sites use so we never inherit ``litellm.api_base`` - # (commonly set by ``AZURE_OPENAI_ENDPOINT``) when the chat - # config itself ships an empty ``api_base``. Without this the - # title-gen on an OpenRouter chat config would 404 against the - # inherited Azure endpoint — see ``provider_api_base`` for the - # same bug repro on the image-gen / vision paths. raw_model = getattr(llm, "model", "") or "" - provider_prefix = raw_model.split("/", 1)[0] if "/" in raw_model else None - provider_value = agent_config.provider if agent_config is not None else None - title_api_base = resolve_api_base( - provider=provider_value, - provider_prefix=provider_prefix, - config_api_base=getattr(llm, "api_base", None), - ) response = await acompletion( model=raw_model, messages=messages, api_key=getattr(llm, "api_key", None), - api_base=title_api_base, + api_base=getattr(llm, "api_base", None), ) usage_info = None diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py index 7e2bc950b..f6870f5fa 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py @@ -1,8 +1,8 @@ """Load an LLM + AgentConfig bundle for a given config id. Handles both code paths uniformly: -- ``config_id >= 0`` → database-backed ``NewLLMConfig`` row (per-user/per-space). -- ``config_id < 0`` → YAML-defined global LLM config (built-in defaults). +- ``config_id > 0`` → database-backed model-connection ``Model`` row. +- ``config_id < 0`` → virtual global model materialized from YAML/OpenRouter. Returns ``(llm, agent_config, error_message)``; on success ``error_message`` is ``None``. The caller emits the friendly SSE error frame. @@ -12,15 +12,72 @@ from __future__ import annotations from typing import Any +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from app.agents.chat.runtime.llm_config import ( AgentConfig, - create_chat_litellm_from_agent_config, - create_chat_litellm_from_config, - load_agent_config, - load_global_llm_config_by_id, + SanitizedChatLiteLLM, ) +from app.config import config +from app.db import Model, SearchSpace +from app.services.model_resolver import to_litellm + + +def _agent_config_from_resolved( + *, + config_id: int, + config_name: str | None, + provider: str, + model_name: str, + api_key: str | None, + api_base: str | None, + litellm_params: dict | None, + supports_image_input: bool, + billing_tier: str = "free", +) -> AgentConfig: + return AgentConfig( + provider=provider, + model_name=model_name, + api_key=api_key or "", + api_base=api_base, + custom_provider=None, + litellm_params=litellm_params, + config_id=config_id, + config_name=config_name, + is_auto_mode=False, + billing_tier=billing_tier, + is_premium=billing_tier == "premium", + supports_image_input=supports_image_input, + ) + + +async def _load_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace | None: + result = await session.execute(select(SearchSpace).where(SearchSpace.id == search_space_id)) + return result.scalars().first() + + +async def _load_db_model( + session: AsyncSession, + *, + model_id: int, + search_space: SearchSpace, +) -> Model | None: + result = await session.execute( + select(Model) + .options(selectinload(Model.connection)) + .where(Model.id == model_id, Model.enabled.is_(True)) + ) + model = result.scalars().first() + if not model or not model.connection or not model.connection.enabled: + return None + conn = model.connection + if conn.search_space_id is not None and conn.search_space_id != search_space.id: + return None + if conn.user_id is not None and conn.user_id != search_space.user_id: + return None + return model async def load_llm_bundle( @@ -29,29 +86,67 @@ async def load_llm_bundle( config_id: int, search_space_id: int, ) -> tuple[Any, AgentConfig | None, str | None]: - if config_id >= 0: - loaded_agent_config = await load_agent_config( - session=session, - config_id=config_id, - search_space_id=search_space_id, + search_space = await _load_search_space(session, search_space_id) + if not search_space: + return None, None, f"Search space {search_space_id} not found" + + if config_id > 0: + model = await _load_db_model( + session, + model_id=config_id, + search_space=search_space, ) - if not loaded_agent_config: + if not model or not (model.capabilities or {}).get("chat"): return ( None, None, - f"Failed to load NewLLMConfig with id {config_id}", + f"Failed to load chat model with id {config_id}", ) + model_string, litellm_kwargs = to_litellm(model.connection, model.model_id) + agent_config = _agent_config_from_resolved( + config_id=config_id, + config_name=model.display_name or model.model_id, + provider=model.connection.litellm_provider or "", + model_name=model.model_id, + api_key=model.connection.api_key, + api_base=model.connection.base_url, + litellm_params=(model.connection.extra or {}).get("litellm_params"), + supports_image_input=bool((model.capabilities or {}).get("vision")), + billing_tier="free", + ) return ( - create_chat_litellm_from_agent_config(loaded_agent_config), - loaded_agent_config, + SanitizedChatLiteLLM(model=model_string, **litellm_kwargs), + agent_config, None, ) - loaded_llm_config = load_global_llm_config_by_id(config_id) - if not loaded_llm_config: - return None, None, f"Failed to load LLM config with id {config_id}" - return ( - create_chat_litellm_from_config(loaded_llm_config), - AgentConfig.from_yaml_config(loaded_llm_config), + global_model = next((m for m in config.GLOBAL_MODELS if m.get("id") == config_id), None) + if not global_model or not (global_model.get("capabilities") or {}).get("chat"): + return None, None, f"Failed to load global chat model with id {config_id}" + global_connection = next( + ( + c + for c in config.GLOBAL_CONNECTIONS + if c.get("id") == global_model.get("connection_id") + ), + None, + ) + if not global_connection: + return None, None, f"Failed to load global connection for model {config_id}" + model_string, litellm_kwargs = to_litellm(global_connection, global_model["model_id"]) + agent_config = _agent_config_from_resolved( + config_id=config_id, + config_name=global_model.get("display_name") or global_model.get("model_id"), + provider=global_connection.get("litellm_provider") or "", + model_name=global_model["model_id"], + api_key=global_connection.get("api_key"), + api_base=global_connection.get("base_url"), + litellm_params=(global_connection.get("extra") or {}).get("litellm_params"), + supports_image_input=bool((global_model.get("capabilities") or {}).get("vision")), + billing_tier=str(global_model.get("billing_tier", "free")).lower(), + ) + return ( + SanitizedChatLiteLLM(model=model_string, **litellm_kwargs), + agent_config, None, ) diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index d1af29aeb..0af41a7ee 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -75,10 +75,10 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): config, "GLOBAL_LLM_CONFIGS", [ - {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"}, { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -117,7 +117,7 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): [ { "id": -2, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", @@ -125,7 +125,7 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): }, { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -164,7 +164,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): [ { "id": -1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5.1", "api_key": "k1", "billing_tier": "premium", @@ -173,7 +173,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): }, { "id": -2, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5.4", "api_key": "k2", "billing_tier": "premium", @@ -182,7 +182,7 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): }, { "id": -3, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "openai/gpt-5.4", "api_key": "k3", "billing_tier": "premium", @@ -222,7 +222,7 @@ async def test_next_turn_reuses_existing_pin(monkeypatch): [ { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -263,7 +263,7 @@ async def test_premium_eligible_auto_can_pin_premium(monkeypatch): [ { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -301,14 +301,14 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch): [ { "id": -2, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", }, { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -346,14 +346,14 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): [ { "id": -2, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", }, { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -391,14 +391,14 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): [ { "id": -2, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1", "billing_tier": "free", }, { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-prem", "api_key": "k2", "billing_tier": "premium", @@ -437,7 +437,7 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch): config, "GLOBAL_LLM_CONFIGS", [ - {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"}, ], ) @@ -462,7 +462,7 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): config, "GLOBAL_LLM_CONFIGS", [ - {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}, + {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"}, ], ) @@ -504,7 +504,7 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch): [ { "id": -1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "venice/dead-model", "api_key": "k1", "billing_tier": "free", @@ -514,7 +514,7 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch): }, { "id": -2, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemini-flash", "api_key": "k1", "billing_tier": "free", @@ -556,7 +556,7 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch): [ { "id": -1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "api_key": "k-yaml", "billing_tier": "premium", @@ -566,7 +566,7 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch): }, { "id": -2, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "openai/gpt-5", "api_key": "k-or", "billing_tier": "premium", @@ -608,7 +608,7 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch [ { "id": -1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "api_key": "k-yaml", "billing_tier": "premium", @@ -618,7 +618,7 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch }, { "id": -2, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemini-flash:free", "api_key": "k-or", "billing_tier": "free", @@ -656,7 +656,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch): high_score_cfgs = [ { "id": -i, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": f"gpt-x-{i}", "api_key": "k", "billing_tier": "premium", @@ -668,7 +668,7 @@ async def test_top_k_picks_only_high_score_models(monkeypatch): ] low_score_trap = { "id": -99, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "tiny-legacy", "api_key": "k", "billing_tier": "premium", @@ -729,7 +729,7 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch): [ { "id": -1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "venice/dead-model", "api_key": "k", "billing_tier": "premium", @@ -739,7 +739,7 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch): }, { "id": -2, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "api_key": "k", "billing_tier": "premium", @@ -781,7 +781,7 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): [ { "id": -1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "api_key": "k", "billing_tier": "premium", @@ -791,7 +791,7 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): }, { "id": -2, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5-pro", "api_key": "k", "billing_tier": "premium", @@ -839,7 +839,7 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): [ { "id": -1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemma-4-26b-a4b-it:free", "api_key": "k", "billing_tier": "free", @@ -849,7 +849,7 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): }, { "id": -2, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemini-2.5-flash:free", "api_key": "k", "billing_tier": "free", @@ -892,7 +892,7 @@ async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch): [ { "id": -1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemma-4-26b-a4b-it:free", "api_key": "k", "billing_tier": "free", @@ -937,7 +937,7 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa [ { "id": -1, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemma-4-26b-a4b-it:free", "api_key": "k", "billing_tier": "free", @@ -947,7 +947,7 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa }, { "id": -2, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemini-2.5-flash:free", "api_key": "k", "billing_tier": "free", diff --git a/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py index 0e19b80e4..e267d59ba 100644 --- a/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py +++ b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py @@ -74,7 +74,7 @@ def _thread(*, pinned: int | None = None): def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict: return { "id": id_, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": f"vision-{id_}", "api_key": "k", "billing_tier": tier, @@ -87,7 +87,7 @@ def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict: def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict: return { "id": id_, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": f"text-{id_}", "api_key": "k", "billing_tier": tier, @@ -261,7 +261,7 @@ async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch): session = _FakeSession(_thread()) cfg_unannotated_vision = { "id": -2, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-4o", # known vision model in LiteLLM map "api_key": "k", "billing_tier": "free", diff --git a/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py index c309ff881..efe906ac0 100644 --- a/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py +++ b/surfsense_backend/tests/unit/services/test_llm_router_pool_filter.py @@ -25,10 +25,10 @@ def _fake_yaml_config( return { "id": id, "name": f"yaml-{id}", - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": model_name, "api_key": "sk-test", - "api_base": "", + "api_base": "https://api.openai.com/v1", "billing_tier": billing_tier, "rpm": 100, "tpm": 100_000, @@ -54,10 +54,10 @@ def _fake_openrouter_config( return { "id": id, "name": f"or-{id}", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": model_name, "api_key": "sk-or-test", - "api_base": "", + "api_base": "https://openrouter.ai/api/v1", "billing_tier": billing_tier, "rpm": 20 if billing_tier == "free" else 200, "tpm": 100_000 if billing_tier == "free" else 1_000_000, diff --git a/surfsense_backend/tests/unit/services/test_or_health_enrichment.py b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py index 1c74aa928..b4b6618a4 100644 --- a/surfsense_backend/tests/unit/services/test_or_health_enrichment.py +++ b/surfsense_backend/tests/unit/services/test_or_health_enrichment.py @@ -25,7 +25,7 @@ def _or_cfg( ) -> dict: return { "id": cid, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": model_name, "billing_tier": tier, "auto_pin_tier": "B" if tier == "premium" else "C", @@ -144,7 +144,7 @@ async def test_enrich_health_only_touches_or_provider(monkeypatch): """YAML cfgs that aren't OPENROUTER must be skipped entirely.""" yaml_cfg = { "id": -1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "billing_tier": "premium", "auto_pin_tier": "A", @@ -313,7 +313,7 @@ async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch): """When the catalogue has no OR cfgs at all, no HTTP calls fire.""" yaml_cfg: dict[str, Any] = { "id": -1, - "provider": "AZURE_OPENAI", + "litellm_provider": "azure", "model_name": "gpt-5", "billing_tier": "premium", } diff --git a/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py index 792d059b0..6bfc72bf3 100644 --- a/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py +++ b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py @@ -35,7 +35,7 @@ def test_safety_net_does_not_fire_for_azure_gpt_4o(): it text-only.""" assert ( is_known_text_only_chat_model( - provider="AZURE_OPENAI", + litellm_provider="azure", model_name="my-azure-deployment", base_model="gpt-4o", ) @@ -49,7 +49,7 @@ def test_safety_net_does_not_fire_for_unknown_model(): LiteLLM doesn't know about must flow through to the provider.""" assert ( is_known_text_only_chat_model( - provider="CUSTOM", + litellm_provider="custom", custom_provider="brand_new_proxy", model_name="brand-new-model-x9", ) @@ -69,7 +69,7 @@ def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch): assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="gpt-4o", ) is False @@ -88,7 +88,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch): monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false) assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="text-only-stub", ) is True @@ -100,7 +100,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch): monkeypatch.setattr(pc.litellm, "get_model_info", _info_true) assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="vision-stub", ) is False @@ -112,7 +112,7 @@ def test_safety_net_fires_only_on_explicit_false(monkeypatch): monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing) assert ( is_known_text_only_chat_model( - provider="OPENAI", + litellm_provider="openai", model_name="missing-key-stub", ) is False From 831ad23c6c8f72d4f731290240cb38c7d44dda31 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 18:22:45 +0530 Subject: [PATCH 20/59] fix(chat): harden image generation model routing --- .../deliverables/tools/generate_image.py | 90 ++++++++++----- .../app/routes/image_generation_routes.py | 107 ++++++++++-------- .../app/services/image_gen_router_service.py | 12 +- .../scripts/verify_chat_image_capability.py | 35 ++---- .../tests/unit/routes/test_image_gen_quota.py | 6 +- .../test_image_gen_api_base_defense.py | 51 +++------ .../test_vision_llm_api_base_defense.py | 26 ++--- 7 files changed, 156 insertions(+), 171 deletions(-) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py index dd980c51c..fda327750 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py @@ -23,20 +23,26 @@ from app.db import ( ) from app.services.image_gen_router_service import ( IMAGE_GEN_AUTO_MODE_ID, - ImageGenRouterService, is_image_gen_auto_mode, ) -from app.services.model_resolver import native_connection_from_config, to_litellm +from app.services.auto_model_pin_service import ( + auto_model_candidates, + choose_auto_model_candidate, +) +from app.services.model_resolver import to_litellm from app.utils.signed_image_urls import generate_image_token logger = logging.getLogger(__name__) -def _get_global_image_gen_config(config_id: int) -> dict | None: - """Get a global image gen config by negative ID.""" - for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: - if cfg.get("id") == config_id: - return cfg - return None +def _get_global_model(model_id: int) -> dict | None: + return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None) + + +def _get_global_connection(connection_id: int) -> dict | None: + return next( + (c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id), + None, + ) def create_generate_image_tool( @@ -93,6 +99,16 @@ def create_generate_image_tool( # task's session is shared across every tool; without isolation, # autoflushes from a concurrent writer poison this tool too. async with shielded_async_session() as session: + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + if not search_space: + return _failed( + {"error": "Search space not found"}, + error="Search space not found", + ) + if image_gen_model_id_override is not None: # Automation run: use the captured image model, insulated from # later search-space changes. No search-space read needed. @@ -100,16 +116,6 @@ def create_generate_image_tool( image_gen_model_id_override or IMAGE_GEN_AUTO_MODE_ID ) else: - result = await session.execute( - select(SearchSpace).filter(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - if not search_space: - return _failed( - {"error": "Search space not found"}, - error="Search space not found", - ) - config_id = ( search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID @@ -122,24 +128,39 @@ def create_generate_image_tool( gen_kwargs["n"] = n if is_image_gen_auto_mode(config_id): - if not ImageGenRouterService.is_initialized(): + candidates = await auto_model_candidates( + session, + search_space_id=search_space_id, + user_id=search_space.user_id, + capability="image_gen", + ) + if not candidates: err = ( - "No image generation models configured. " + "No image generation models available. " "Please add an image model in Settings > Image Models." ) return _failed({"error": err}, error=err) - response = await ImageGenRouterService.aimage_generation( - prompt=prompt, model="auto", **gen_kwargs + config_id = int( + choose_auto_model_candidate(candidates, search_space_id)["id"] ) - elif config_id < 0: - cfg = _get_global_image_gen_config(config_id) - if not cfg: - err = f"Image generation config {config_id} not found" + + if config_id < 0: + global_model = _get_global_model(config_id) + if not global_model or not ( + global_model.get("capabilities") or {} + ).get("image_gen"): + err = f"Image generation model {config_id} not found" + return _failed({"error": err}, error=err) + global_connection = _get_global_connection( + global_model["connection_id"] + ) + if not global_connection: + err = f"Image generation connection for model {config_id} not found" return _failed({"error": err}, error=err) model_string, resolved_kwargs = to_litellm( - native_connection_from_config(cfg), - cfg["model_name"], + global_connection, + global_model["model_id"], ) gen_kwargs.update(resolved_kwargs) @@ -157,6 +178,19 @@ def create_generate_image_tool( if not db_model or not db_model.connection or not db_model.connection.enabled: err = f"Image generation model {config_id} not found" return _failed({"error": err}, error=err) + conn = db_model.connection + if ( + conn.search_space_id is not None + and conn.search_space_id != search_space_id + ): + err = f"Image generation model {config_id} not found" + return _failed({"error": err}, error=err) + if ( + conn.user_id is not None + and conn.user_id != search_space.user_id + ): + err = f"Image generation model {config_id} not found" + return _failed({"error": err}, error=err) if not (db_model.capabilities or {}).get("image_gen"): err = f"Model {config_id} is not image-generation capable" return _failed({"error": err}, error=err) diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 0de368d57..5be1cedf1 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -45,10 +45,13 @@ from app.services.billable_calls import ( ) from app.services.image_gen_router_service import ( IMAGE_GEN_AUTO_MODE_ID, - ImageGenRouterService, is_image_gen_auto_mode, ) -from app.services.model_resolver import native_connection_from_config, to_litellm +from app.services.auto_model_pin_service import ( + auto_model_candidates, + choose_auto_model_candidate, +) +from app.services.model_resolver import to_litellm from app.users import current_active_user from app.utils.rbac import check_permission from app.utils.signed_image_urls import verify_image_token @@ -56,22 +59,15 @@ from app.utils.signed_image_urls import verify_image_token router = APIRouter() logger = logging.getLogger(__name__) -def _get_global_image_gen_config(config_id: int) -> dict | None: - """Get a global image generation configuration by ID (negative IDs).""" - if config_id == IMAGE_GEN_AUTO_MODE_ID: - return { - "id": IMAGE_GEN_AUTO_MODE_ID, - "name": "Auto (Fastest)", - "provider": "AUTO", - "model_name": "auto", - "is_auto_mode": True, - } - if config_id > 0: - return None - for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: - if cfg.get("id") == config_id: - return cfg - return None +def _get_global_model(model_id: int) -> dict | None: + return next((m for m in config.GLOBAL_MODELS if m.get("id") == model_id), None) + + +def _get_global_connection(connection_id: int) -> dict | None: + return next( + (c for c in config.GLOBAL_CONNECTIONS if c.get("id") == connection_id), + None, + ) async def _resolve_billing_for_image_gen( @@ -87,30 +83,41 @@ async def _resolve_billing_for_image_gen( config that will actually run, and so we don't open an ``ImageGeneration`` row for a request that's about to 402. - User-owned (positive ID) BYOK configs are always free — they cost - the user nothing on our side. Auto mode currently treats as free - because the underlying router can dispatch to either premium or - free YAML configs and we don't surface the resolved deployment up - here yet. Bringing Auto under premium billing would require - threading the chosen deployment back from ``ImageGenRouterService``. + User-owned (positive ID) BYOK models are always free — they cost + the user nothing on our side. Auto mode resolves to one concrete + global or BYOK model before billing is calculated. """ resolved_id = config_id if resolved_id is None: resolved_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID if is_image_gen_auto_mode(resolved_id): - return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS) + candidates = await auto_model_candidates( + session, + search_space_id=search_space.id, + user_id=search_space.user_id, + capability="image_gen", + ) + if not candidates: + return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS) + selected = choose_auto_model_candidate(candidates, search_space.id) + resolved_id = int(selected["id"]) if resolved_id < 0: - cfg = _get_global_image_gen_config(resolved_id) or {} - billing_tier = str(cfg.get("billing_tier", "free")).lower() - base_model, _ = to_litellm(native_connection_from_config(cfg), cfg.get("model_name", "")) + global_model = _get_global_model(resolved_id) or {} + global_connection = _get_global_connection(global_model.get("connection_id", 0)) + billing_tier = str(global_model.get("billing_tier", "free")).lower() + if global_connection and global_model.get("model_id"): + base_model, _ = to_litellm(global_connection, global_model["model_id"]) + else: + base_model = "global_image_model" + catalog = global_model.get("catalog") or {} reserve_micros = int( - cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS + catalog.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS ) return (billing_tier, base_model, reserve_micros) - # Positive ID = user-owned BYOK image-gen config — always free. + # Positive ID = user-owned BYOK image-gen model — always free. return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS) @@ -146,23 +153,28 @@ async def _execute_image_generation( gen_kwargs["response_format"] = image_gen.response_format if is_image_gen_auto_mode(config_id): - if not ImageGenRouterService.is_initialized(): - raise ValueError( - "Auto mode requested but Image Generation Router not initialized. " - "Ensure global_llm_config.yaml has global_image_generation_configs." - ) - response = await ImageGenRouterService.aimage_generation( - prompt=image_gen.prompt, model="auto", **gen_kwargs + candidates = await auto_model_candidates( + session, + search_space_id=search_space.id, + user_id=search_space.user_id, + capability="image_gen", ) - elif config_id < 0: - # Global config from YAML - cfg = _get_global_image_gen_config(config_id) - if not cfg: - raise ValueError(f"Global image generation config {config_id} not found") + if not candidates: + raise ValueError("No image-generation models are available for Auto mode") + config_id = int(choose_auto_model_candidate(candidates, search_space.id)["id"]) + image_gen.image_generation_config_id = config_id + + if config_id < 0: + global_model = _get_global_model(config_id) + if not global_model or not (global_model.get("capabilities") or {}).get("image_gen"): + raise ValueError(f"Global image generation model {config_id} not found") + global_connection = _get_global_connection(global_model["connection_id"]) + if not global_connection: + raise ValueError(f"Global connection for image model {config_id} not found") model_string, resolved_kwargs = to_litellm( - native_connection_from_config(cfg), - cfg["model_name"], + global_connection, + global_model["model_id"], ) gen_kwargs.update(resolved_kwargs) @@ -183,6 +195,11 @@ async def _execute_image_generation( db_model = result.scalars().first() if not db_model or not db_model.connection or not db_model.connection.enabled: raise ValueError(f"Image generation model {config_id} not found") + conn = db_model.connection + if conn.search_space_id is not None and conn.search_space_id != search_space.id: + raise ValueError(f"Image generation model {config_id} not found") + if conn.user_id is not None and conn.user_id != search_space.user_id: + raise ValueError(f"Image generation model {config_id} not found") if not (db_model.capabilities or {}).get("image_gen"): raise ValueError(f"Model {config_id} is not image-generation capable") @@ -255,7 +272,7 @@ async def get_global_image_gen_configs( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("provider"), + "provider": cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/services/image_gen_router_service.py b/surfsense_backend/app/services/image_gen_router_service.py index 0b03f5c6d..3716537b9 100644 --- a/surfsense_backend/app/services/image_gen_router_service.py +++ b/surfsense_backend/app/services/image_gen_router_service.py @@ -20,23 +20,13 @@ from typing import Any from litellm import Router from litellm.utils import ImageResponse -from app.services.model_resolver import ( - NATIVE_PROVIDER_PREFIX, - native_connection_from_config, - to_litellm, -) +from app.services.model_resolver import native_connection_from_config, to_litellm logger = logging.getLogger(__name__) # Special ID for Auto mode - uses router for load balancing IMAGE_GEN_AUTO_MODE_ID = 0 -# Provider mapping for LiteLLM model string construction. -# Only includes providers that support image generation. -# See: https://docs.litellm.ai/docs/image_generation#supported-providers -IMAGE_GEN_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX - - class ImageGenRouterService: """ Singleton service for managing LiteLLM Router for image generation. diff --git a/surfsense_backend/scripts/verify_chat_image_capability.py b/surfsense_backend/scripts/verify_chat_image_capability.py index a49d4eab2..6e711f99a 100644 --- a/surfsense_backend/scripts/verify_chat_image_capability.py +++ b/surfsense_backend/scripts/verify_chat_image_capability.py @@ -55,7 +55,6 @@ from app.services.openrouter_integration_service import ( # noqa: E402 _OPENROUTER_DYNAMIC_MARKER, OpenRouterIntegrationService, ) -from app.services.provider_api_base import resolve_api_base # noqa: E402 from app.services.provider_capabilities import ( # noqa: E402 derive_supports_image_input, is_known_text_only_chat_model, @@ -154,13 +153,13 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]: litellm_params.get("base_model") if isinstance(litellm_params, dict) else None ) cap = derive_supports_image_input( - provider=cfg.get("provider"), + litellm_provider=cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), ) block = is_known_text_only_chat_model( - provider=cfg.get("provider"), + litellm_provider=cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), @@ -179,11 +178,7 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]: def _build_chat_model_string(cfg: dict) -> str: if cfg.get("custom_provider"): return f"{cfg['custom_provider']}/{cfg['model_name']}" - from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP - - prefix = _PROVIDER_PREFIX_MAP.get( - (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower() - ) + prefix = cfg.get("litellm_provider") or "openai" return f"{prefix}/{cfg['model_name']}" @@ -195,11 +190,6 @@ def _build_chat_model_string(cfg: dict) -> str: async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]: """Send a 1x1 PNG + `reply with one word: ok` to the chat config.""" model_string = _build_chat_model_string(cfg) - api_base = resolve_api_base( - provider=cfg.get("provider"), - provider_prefix=model_string.split("/", 1)[0], - config_api_base=cfg.get("api_base") or None, - ) kwargs: dict[str, Any] = { "model": model_string, "api_key": cfg.get("api_key"), @@ -218,8 +208,8 @@ async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]: "max_tokens": 16, "timeout": 60, } - if api_base: - kwargs["api_base"] = api_base + if cfg.get("api_base"): + kwargs["api_base"] = cfg["api_base"] if cfg.get("litellm_params"): # Strip pricing keys — they're tracking-only and confuse some # provider validators (e.g. azure/openai reject unknown kwargs @@ -257,20 +247,11 @@ _IMAGE_GEN_PROMPTS: tuple[str, ...] = ( async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]: """Generate one tiny image to verify the deployment is reachable.""" - from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP - if cfg.get("custom_provider"): prefix = cfg["custom_provider"] else: - prefix = _PROVIDER_PREFIX_MAP.get( - (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower() - ) + prefix = cfg.get("litellm_provider") or "openai" model_string = f"{prefix}/{cfg['model_name']}" - api_base = resolve_api_base( - provider=cfg.get("provider"), - provider_prefix=prefix, - config_api_base=cfg.get("api_base") or None, - ) base_kwargs: dict[str, Any] = { "model": model_string, "api_key": cfg.get("api_key"), @@ -278,8 +259,8 @@ async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]: "size": "1024x1024", "timeout": 120, } - if api_base: - base_kwargs["api_base"] = api_base + if cfg.get("api_base"): + base_kwargs["api_base"] = cfg["api_base"] if cfg.get("api_version"): base_kwargs["api_version"] = cfg["api_version"] if cfg.get("litellm_params"): diff --git a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py index 636b7de31..53c0f50a9 100644 --- a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py +++ b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py @@ -49,14 +49,14 @@ async def test_resolve_billing_for_premium_global_config(monkeypatch): [ { "id": -1, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-image-1", "billing_tier": "premium", "quota_reserve_micros": 75_000, }, { "id": -2, - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "google/gemini-2.5-flash-image", "billing_tier": "free", }, @@ -118,7 +118,7 @@ async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch): [ { "id": -7, - "provider": "OPENAI", + "litellm_provider": "openai", "model_name": "gpt-image-1", "billing_tier": "premium", } diff --git a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py index 571e7d15b..63aa934a3 100644 --- a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py +++ b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py @@ -1,19 +1,4 @@ -"""Defense-in-depth: image-gen call sites must not let an empty -``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``. - -The bug repro: an OpenRouter image-gen config ships -``api_base=""``. The pre-fix call site in -``image_generation_routes._execute_image_generation`` did -``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which -silently dropped the empty string. LiteLLM then fell back to -``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``) -and OpenRouter's ``image_generation/transformation`` appended -``/chat/completions`` to it → 404 ``Resource not found``. - -This test pins the post-fix behaviour: with an empty ``api_base`` in -the config, the call site MUST set ``api_base`` to OpenRouter's public -URL instead of leaving it unset. -""" +"""Image-gen call sites must pass each config's explicit ``api_base``.""" from __future__ import annotations @@ -26,20 +11,17 @@ pytestmark = pytest.mark.unit @pytest.mark.asyncio -async def test_global_openrouter_image_gen_sets_api_base_when_config_empty(): - """The global-config branch (``config_id < 0``) of - ``_execute_image_generation`` must apply the resolver and pin - ``api_base`` to OpenRouter when the config ships an empty string. - """ +async def test_global_openrouter_image_gen_sets_explicit_api_base(): + """The global-config branch forwards the explicit OpenRouter base.""" from app.routes import image_generation_routes cfg = { "id": -20_001, "name": "GPT Image 1 (OpenRouter)", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "openai/gpt-image-1", "api_key": "sk-or-test", - "api_base": "", # the original bug shape + "api_base": "https://openrouter.ai/api/v1", "api_version": None, "litellm_params": {}, } @@ -80,16 +62,13 @@ async def test_global_openrouter_image_gen_sets_api_base_when_config_empty(): session=session, image_gen=image_gen, search_space=search_space ) - # The whole point of the fix: even with empty ``api_base`` in the - # config, we forward OpenRouter's public URL so the call doesn't - # inherit an Azure endpoint. assert captured.get("api_base") == "https://openrouter.ai/api/v1" assert captured["model"] == "openrouter/openai/gpt-image-1" @pytest.mark.asyncio -async def test_generate_image_tool_global_sets_api_base_when_config_empty(): - """Same defense at the agent tool entry point — both surfaces share +async def test_generate_image_tool_global_sets_explicit_api_base(): + """Same explicit-base behavior at the agent tool entry point — both surfaces share the same OpenRouter config payloads.""" from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools import ( generate_image as gi_module, @@ -98,10 +77,10 @@ async def test_generate_image_tool_global_sets_api_base_when_config_empty(): cfg = { "id": -20_001, "name": "GPT Image 1 (OpenRouter)", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "openai/gpt-image-1", "api_key": "sk-or-test", - "api_base": "", + "api_base": "https://openrouter.ai/api/v1", "api_version": None, "litellm_params": {}, } @@ -171,20 +150,16 @@ async def test_generate_image_tool_global_sets_api_base_when_config_empty(): assert captured["model"] == "openrouter/openai/gpt-image-1" -def test_image_gen_router_deployment_sets_api_base_when_config_empty(): - """The Auto-mode router pool must also resolve ``api_base`` when an - OpenRouter config ships an empty string. The deployment dict is fed - straight to ``litellm.Router``, so a missing ``api_base`` would - leak the same way as the direct call sites. - """ +def test_image_gen_router_deployment_sets_explicit_api_base(): + """The Auto-mode router pool carries explicit api_base into deployments.""" from app.services.image_gen_router_service import ImageGenRouterService deployment = ImageGenRouterService._config_to_deployment( { "model_name": "openai/gpt-image-1", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "api_key": "sk-or-test", - "api_base": "", + "api_base": "https://openrouter.ai/api/v1", } ) assert deployment is not None diff --git a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py index 5e3aa6eda..48dfc8e0b 100644 --- a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py +++ b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py @@ -1,12 +1,4 @@ -"""Defense-in-depth: vision-LLM resolution must not leak ``api_base`` -defaults from ``litellm.api_base`` either. - -Vision shares the same shape as image-gen — global YAML / OpenRouter -dynamic configs ship ``api_base=""`` and the pre-fix ``get_vision_llm`` -call sites would silently drop the empty string and inherit -``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on -construction so we test the kwargs we hand to it instead. -""" +"""Vision LLM resolution must pass explicit per-config ``api_base``.""" from __future__ import annotations @@ -19,19 +11,16 @@ pytestmark = pytest.mark.unit @pytest.mark.asyncio async def test_get_vision_llm_global_openrouter_sets_api_base(): - """Global negative-ID branch: an OpenRouter vision config with - ``api_base=""`` must end up calling ``SanitizedChatLiteLLM`` with - ``api_base="https://openrouter.ai/api/v1"`` — never an empty string, - never silently absent.""" + """Global negative-ID branch forwards the explicit OpenRouter base.""" from app.services import llm_service cfg = { "id": -30_001, "name": "GPT-4o Vision (OpenRouter)", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "model_name": "openai/gpt-4o", "api_key": "sk-or-test", - "api_base": "", + "api_base": "https://openrouter.ai/api/v1", "api_version": None, "litellm_params": {}, "billing_tier": "free", @@ -72,16 +61,15 @@ async def test_get_vision_llm_global_openrouter_sets_api_base(): def test_vision_router_deployment_sets_api_base_when_config_empty(): - """Auto-mode vision router: deployments are fed to ``litellm.Router``, - so the resolver has to apply at deployment construction time too.""" + """Auto-mode vision router carries explicit api_base into deployments.""" from app.services.vision_llm_router_service import VisionLLMRouterService deployment = VisionLLMRouterService._config_to_deployment( { "model_name": "openai/gpt-4o", - "provider": "OPENROUTER", + "litellm_provider": "openrouter", "api_key": "sk-or-test", - "api_base": "", + "api_base": "https://openrouter.ai/api/v1", } ) assert deployment is not None From 3089dd4cb6b9ba1bc8132ef50967bd70b6a3f33b Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 18:22:57 +0530 Subject: [PATCH 21/59] refactor(model-connections): simplify connection settings UI --- .../model-connections-mutation.atoms.ts | 12 +- .../components/new-chat/model-selector.tsx | 4 +- .../settings/model-connections-settings.tsx | 234 +++++++----------- .../types/model-connections.types.ts | 8 +- .../hooks/use-automation-eligible-models.ts | 2 +- .../lib/apis/model-connections-api.service.ts | 31 ++- 6 files changed, 126 insertions(+), 165 deletions(-) diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts index 76289e60d..101bad1b5 100644 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -4,8 +4,10 @@ import type { ConnectionCreateRequest, ConnectionUpdateRequest, ModelCreateRequest, + ModelRead, ModelRoles, ModelUpdateRequest, + VerifyConnectionResponse, } from "@/contracts/types/model-connections.types"; import { modelConnectionsApiService } from "@/lib/apis/model-connections-api.service"; import { cacheKeys } from "@/lib/query-client/cache-keys"; @@ -67,7 +69,7 @@ export const verifyModelConnectionMutationAtom = atomWithMutation((get) => { return { mutationKey: ["model-connections", "verify"], mutationFn: (id: number) => modelConnectionsApiService.verifyConnection(id), - onSuccess: (result) => { + onSuccess: (result: VerifyConnectionResponse) => { if (result.ok) { toast.success("Connection verified"); } else { @@ -90,11 +92,9 @@ export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => { return { mutationKey: ["model-connections", "discover"], mutationFn: (id: number) => modelConnectionsApiService.discoverModels(id), - onSuccess: (models) => { + onSuccess: (models: ModelRead[]) => { toast.success( - models.length - ? `${models.length} models discovered` - : "No models found for this connection" + models.length ? `${models.length} models discovered` : "No models found for this connection" ); invalidateModelConnections(searchSpaceId); }, @@ -132,7 +132,7 @@ export const testModelMutationAtom = atomWithMutation((get) => { return { mutationKey: ["models", "test"], mutationFn: (id: number) => modelConnectionsApiService.testModel(id), - onSuccess: (result) => { + onSuccess: (result: VerifyConnectionResponse) => { if (result.ok) toast.success("Model test succeeded"); else toast.error(result.message || "Model test failed"); invalidateModelConnections(searchSpaceId); diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 7c912afbb..4744da617 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -56,7 +56,7 @@ function modelName(model: ModelRead) { function connectionLabel(connection: ConnectionRead) { if (connection.scope === "GLOBAL") return "Hosted"; - return connection.native_provider || connection.protocol; + return connection.litellm_provider || connection.protocol; } function flattenChatModels(connections: ConnectionRead[]) { @@ -67,7 +67,7 @@ function flattenChatModels(connections: ConnectionRead[]) { ...model, connectionId: connection.id, connectionLabel: connectionLabel(connection), - provider: connection.native_provider || connection.protocol, + provider: connection.litellm_provider || connection.protocol, })) ); } diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 59873408f..29501abda 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -2,7 +2,7 @@ import { useAtom, useAtomValue } from "jotai"; import { CheckCircle2, PlugZap, Plus, RefreshCcw, XCircle } from "lucide-react"; -import { useMemo, useState } from "react"; +import { useState } from "react"; import { addManualModelMutationAtom, createModelConnectionMutationAtom, @@ -35,60 +35,32 @@ import type { ConnectionRead, ModelRead, } from "@/contracts/types/model-connections.types"; -import { isCloud } from "@/lib/env-config"; import { getProviderIcon } from "@/lib/provider-icons"; -type Preset = { - id: string; - label: string; - protocol: ConnectionProtocol; - nativeProvider?: string; - baseUrl?: string; - local?: boolean; -}; - -const PRESETS: Preset[] = [ - { id: "custom", label: "OpenAI-compatible (any URL)", protocol: "OPENAI_COMPATIBLE" }, - { id: "openai", label: "OpenAI", protocol: "NATIVE", nativeProvider: "OPENAI" }, - { id: "anthropic", label: "Anthropic", protocol: "NATIVE", nativeProvider: "ANTHROPIC" }, - { id: "openrouter", label: "OpenRouter", protocol: "NATIVE", nativeProvider: "OPENROUTER" }, +const PROTOCOL_OPTIONS: { value: ConnectionProtocol; label: string; description: string }[] = [ { - id: "ollama", + value: "OPENAI_COMPATIBLE", + label: "OpenAI-compatible", + description: "Use for OpenAI, OpenRouter, Groq, vLLM, LM Studio, and compatible APIs.", + }, + { + value: "ANTHROPIC", + label: "Anthropic", + description: "Use for Claude endpoints that require Anthropic headers.", + }, + { + value: "OLLAMA", label: "Ollama", - protocol: "OLLAMA", - baseUrl: "http://host.docker.internal:11434", - local: true, - }, - { - id: "lmstudio", - label: "LM Studio", - protocol: "OPENAI_COMPATIBLE", - baseUrl: "http://host.docker.internal:1234/v1", - local: true, - }, - { - id: "llamacpp", - label: "llama.cpp", - protocol: "OPENAI_COMPATIBLE", - baseUrl: "http://host.docker.internal:8080/v1", - local: true, - }, - { - id: "localai", - label: "LocalAI", - protocol: "OPENAI_COMPATIBLE", - baseUrl: "http://host.docker.internal:8080/v1", - local: true, - }, - { - id: "vllm", - label: "vLLM", - protocol: "OPENAI_COMPATIBLE", - baseUrl: "http://host.docker.internal:8000/v1", - local: true, + description: "Use for Ollama's native API.", }, ]; +function defaultLitellmProvider(protocol: ConnectionProtocol) { + if (protocol === "OLLAMA") return "ollama_chat"; + if (protocol === "ANTHROPIC") return "anthropic"; + return "openai"; +} + // Free-text URL hints (datalist), mirroring OpenWebUI. These never restrict // what the user can type — any OpenAI-compatible endpoint works. const URL_SUGGESTIONS = [ @@ -135,9 +107,9 @@ function flattenModels(connections: ConnectionRead[]) { return connections.flatMap((connection) => connection.models.map((model) => ({ ...model, - connectionName: connection.native_provider || connection.protocol, + connectionName: connection.litellm_provider || connection.protocol, connectionId: connection.id, - provider: connection.native_provider || connection.protocol, + provider: connection.litellm_provider || connection.protocol, })) ); } @@ -156,7 +128,7 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { const [allowlistText, setAllowlistText] = useState(allowlist.join(", ")); const [manualModelId, setManualModelId] = useState(""); - const providerLabel = connection.native_provider || connection.protocol; + const providerLabel = connection.litellm_provider || connection.protocol; const isLocal = connection.protocol === "OLLAMA" || !connection.base_url?.startsWith("https"); function saveAllowlist() { @@ -200,11 +172,7 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { > Test -
@@ -212,8 +180,8 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { {connection.last_status && connection.last_status !== "OK" ? (

- {connection.last_error || "Could not list models."} Chat may still work — add model - IDs manually below. + {connection.last_error || "Could not list models."} Chat may still work — add model IDs + manually below.

) : null} @@ -236,8 +204,8 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {

- Leave empty to discover all models. Recommended for providers with large catalogs - (e.g. OpenRouter). + Leave empty to discover all models. Recommended for providers with large catalogs (e.g. + OpenRouter).

) : null} @@ -314,20 +282,14 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num const createConnection = useAtomValue(createModelConnectionMutationAtom); const updateRoles = useAtomValue(updateModelRolesMutationAtom); - const visiblePresets = useMemo( - () => PRESETS.filter((preset) => !(isCloud() && preset.local)), - [] - ); - const [presetId, setPresetId] = useState(visiblePresets[0]?.id ?? "custom"); - const preset = visiblePresets.find((item) => item.id === presetId) ?? visiblePresets[0]; - const [baseUrl, setBaseUrl] = useState(preset?.baseUrl ?? ""); + const [protocol, setProtocol] = useState("OPENAI_COMPATIBLE"); + const [baseUrl, setBaseUrl] = useState(""); const [apiKey, setApiKey] = useState(""); - // Native providers carry their endpoint inside LiteLLM, so Base URL is hidden - // by default and only revealed for power users who want to override it. - const [showCustomEndpoint, setShowCustomEndpoint] = useState(false); - - const isNative = preset?.protocol === "NATIVE"; - const requiresUrl = !isNative; + const [litellmProvider, setLitellmProvider] = useState(""); + const [showAdvancedProvider, setShowAdvancedProvider] = useState(false); + const selectedProtocol = PROTOCOL_OPTIONS.find((item) => item.value === protocol); + const protocolDefaultProvider = defaultLitellmProvider(protocol); + const isOllama = protocol === "OLLAMA"; const allConnections = [...globalConnections, ...connections]; const enabledModels = flattenModels(allConnections).filter((model) => model.enabled); @@ -335,21 +297,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num const visionModels = enabledModels.filter((model) => capability(model, "vision")); const imageModels = enabledModels.filter((model) => capability(model, "image_gen")); - function onPresetChange(value: string) { - setPresetId(value); - const next = visiblePresets.find((item) => item.id === value); - // Native providers use LiteLLM's built-in endpoint; everything else needs - // (and may prefill) a Base URL. - setBaseUrl(next?.protocol === "NATIVE" ? "" : (next?.baseUrl ?? "")); - setShowCustomEndpoint(false); - } - function handleCreate() { - if (!preset) return; + const explicitProvider = litellmProvider.trim(); createConnection.mutate( { - protocol: preset.protocol, - native_provider: preset.nativeProvider, + protocol, + litellm_provider: explicitProvider ? explicitProvider : null, base_url: baseUrl || null, api_key: apiKey || null, scope: "SEARCH_SPACE", @@ -384,90 +337,89 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
- - setProtocol(value as ConnectionProtocol)} + > - {visiblePresets.map((item) => ( - - - {getProviderIcon(item.nativeProvider || item.protocol, { - className: "size-4", - })} - {item.label} - + {PROTOCOL_OPTIONS.map((item) => ( + + {item.label} ))}
- - {isNative && !showCustomEndpoint ? ( -
-
- Uses provider default -
- -
- ) : ( - <> - setBaseUrl(event.target.value)} - placeholder="https://api.example.com/v1" - list="model-conn-url-suggestions" - /> - - {URL_SUGGESTIONS.map((url) => ( - - - )} + + setBaseUrl(event.target.value)} + placeholder={ + isOllama ? "http://host.docker.internal:11434" : "https://api.example.com/v1" + } + list="model-conn-url-suggestions" + /> + + {URL_SUGGESTIONS.map((url) => ( +
- + setApiKey(event.target.value)} - placeholder={preset?.local ? "Optional for local models" : "API key"} + placeholder={isOllama ? "Optional for Ollama" : "API key"} type="password" />
- {preset?.local ? ( +

- Local URLs are tested from the backend container. Use host.docker.internal instead of - localhost. + {selectedProtocol?.description} Base URL is explicit and editable; no provider presets + are required. Local URLs are tested from the backend container, so use + host.docker.internal instead of localhost.

- ) : isNative ? ( -

- Just paste an API key — {preset?.label} routes through its native endpoint - automatically. After adding, hit Discover (or add model IDs manually). -

- ) : preset?.protocol === "OPENAI_COMPATIBLE" ? ( -

- Enter any OpenAI-compatible endpoint (OpenRouter, Together, Groq, vLLM, LM Studio…). - After adding, hit Discover to list models. -

- ) : null} +
+ + {showAdvancedProvider ? ( +
+ + setLitellmProvider(event.target.value)} + placeholder={protocolDefaultProvider} + /> +

+ Leave empty to use the protocol default. Set this for more accurate LiteLLM + capabilities/costs, for example openrouter, groq, gemini, or azure. +

+
+ ) : null} +
+
{connections.map((connection) => ( diff --git a/surfsense_web/contracts/types/model-connections.types.ts b/surfsense_web/contracts/types/model-connections.types.ts index dcc875251..7a37799c4 100644 --- a/surfsense_web/contracts/types/model-connections.types.ts +++ b/surfsense_web/contracts/types/model-connections.types.ts @@ -1,6 +1,6 @@ import { z } from "zod"; -export const connectionProtocolEnum = z.enum(["OLLAMA", "OPENAI_COMPATIBLE", "NATIVE"]); +export const connectionProtocolEnum = z.enum(["OLLAMA", "OPENAI_COMPATIBLE", "ANTHROPIC"]); export const connectionScopeEnum = z.enum(["GLOBAL", "SEARCH_SPACE", "USER"]); export const modelSourceEnum = z.enum(["DISCOVERED", "MANUAL"]); @@ -32,7 +32,7 @@ export const modelRead = z.object({ export const connectionRead = z.object({ id: z.number(), protocol: z.union([connectionProtocolEnum, z.string()]), - native_provider: z.string().nullable().optional(), + litellm_provider: z.string().nullable().optional(), base_url: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).default({}), scope: z.union([connectionScopeEnum, z.string()]), @@ -49,7 +49,7 @@ export const connectionRead = z.object({ export const connectionCreateRequest = z.object({ protocol: connectionProtocolEnum, - native_provider: z.string().nullable().optional(), + litellm_provider: z.string().nullable().optional(), base_url: z.string().nullable().optional(), api_key: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).default({}), @@ -59,7 +59,7 @@ export const connectionCreateRequest = z.object({ }); export const connectionUpdateRequest = z.object({ - native_provider: z.string().nullable().optional(), + litellm_provider: z.string().nullable().optional(), base_url: z.string().nullable().optional(), api_key: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).optional(), diff --git a/surfsense_web/hooks/use-automation-eligible-models.ts b/surfsense_web/hooks/use-automation-eligible-models.ts index e75235c56..f8b264162 100644 --- a/surfsense_web/hooks/use-automation-eligible-models.ts +++ b/surfsense_web/hooks/use-automation-eligible-models.ts @@ -51,7 +51,7 @@ function buildKind( id: model.id, name: model.display_name || model.model_id, modelName: model.model_id, - provider: connection.native_provider || connection.protocol, + provider: connection.litellm_provider || connection.protocol, isBYOK, }); diff --git a/surfsense_web/lib/apis/model-connections-api.service.ts b/surfsense_web/lib/apis/model-connections-api.service.ts index 7d0f0f59c..92eec8e61 100644 --- a/surfsense_web/lib/apis/model-connections-api.service.ts +++ b/surfsense_web/lib/apis/model-connections-api.service.ts @@ -1,11 +1,13 @@ import { type ConnectionCreateRequest, + type ConnectionRead, type ConnectionUpdateRequest, connectionCreateRequest, connectionListResponse, connectionRead, connectionUpdateRequest, type ModelCreateRequest, + type ModelRead, type ModelRoles, type ModelUpdateRequest, modelCreateRequest, @@ -13,24 +15,25 @@ import { modelRead, modelRoles, modelUpdateRequest, + type VerifyConnectionResponse, verifyConnectionResponse, } from "@/contracts/types/model-connections.types"; import { ValidationError } from "../error"; import { baseApiService } from "./base-api.service"; class ModelConnectionsApiService { - getGlobalConnections = async () => { + getGlobalConnections = async (): Promise => { return baseApiService.get(`/api/v1/global-model-connections`, connectionListResponse); }; - getConnections = async (searchSpaceId: number) => { + getConnections = async (searchSpaceId: number): Promise => { return baseApiService.get( `/api/v1/model-connections?search_space_id=${searchSpaceId}`, connectionListResponse ); }; - createConnection = async (request: ConnectionCreateRequest) => { + createConnection = async (request: ConnectionCreateRequest): Promise => { const parsed = connectionCreateRequest.safeParse(request); if (!parsed.success) { throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); @@ -40,7 +43,10 @@ class ModelConnectionsApiService { }); }; - updateConnection = async (id: number, request: ConnectionUpdateRequest) => { + updateConnection = async ( + id: number, + request: ConnectionUpdateRequest + ): Promise => { const parsed = connectionUpdateRequest.safeParse(request); if (!parsed.success) { throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); @@ -54,15 +60,18 @@ class ModelConnectionsApiService { return baseApiService.delete(`/api/v1/model-connections/${id}`, undefined); }; - verifyConnection = async (id: number) => { + verifyConnection = async (id: number): Promise => { return baseApiService.post(`/api/v1/model-connections/${id}/verify`, verifyConnectionResponse); }; - discoverModels = async (id: number) => { + discoverModels = async (id: number): Promise => { return baseApiService.post(`/api/v1/model-connections/${id}/discover`, modelListResponse); }; - addManualModel = async (connectionId: number, request: ModelCreateRequest) => { + addManualModel = async ( + connectionId: number, + request: ModelCreateRequest + ): Promise => { const parsed = modelCreateRequest.safeParse(request); if (!parsed.success) { throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); @@ -72,7 +81,7 @@ class ModelConnectionsApiService { }); }; - updateModel = async (id: number, request: ModelUpdateRequest) => { + updateModel = async (id: number, request: ModelUpdateRequest): Promise => { const parsed = modelUpdateRequest.safeParse(request); if (!parsed.success) { throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); @@ -82,15 +91,15 @@ class ModelConnectionsApiService { }); }; - testModel = async (id: number) => { + testModel = async (id: number): Promise => { return baseApiService.post(`/api/v1/models/${id}/test`, verifyConnectionResponse); }; - getModelRoles = async (searchSpaceId: number) => { + getModelRoles = async (searchSpaceId: number): Promise => { return baseApiService.get(`/api/v1/search-spaces/${searchSpaceId}/model-roles`, modelRoles); }; - updateModelRoles = async (searchSpaceId: number, roles: ModelRoles) => { + updateModelRoles = async (searchSpaceId: number, roles: ModelRoles): Promise => { return baseApiService.put(`/api/v1/search-spaces/${searchSpaceId}/model-roles`, modelRoles, { body: roles, }); From 5d5d574550db44419cd5a3e112a2329d8c8923ea Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 02:17:22 +0530 Subject: [PATCH 22/59] refactor(model-connections): move backend model connections to provider capabilities --- .../versions/156_add_model_connections.py | 121 ++++--- .../deliverables/tools/generate_image.py | 7 +- .../app/agents/chat/runtime/llm_config.py | 43 ++- surfsense_backend/app/app.py | 2 - surfsense_backend/app/celery_app.py | 2 - surfsense_backend/app/config/__init__.py | 50 +-- surfsense_backend/app/db.py | 17 +- .../app/routes/anonymous_chat_routes.py | 4 +- .../app/routes/image_generation_routes.py | 7 +- .../app/routes/model_connections_routes.py | 98 ++++-- .../app/routes/new_chat_routes.py | 13 +- .../app/routes/new_llm_config_routes.py | 6 +- .../app/routes/search_spaces_routes.py | 6 +- .../app/routes/vision_llm_routes.py | 2 +- surfsense_backend/app/schemas/__init__.py | 1 + .../app/schemas/model_connections.py | 33 +- .../app/services/auto_model_pin_service.py | 30 +- .../app/services/global_model_catalog.py | 28 +- surfsense_backend/app/services/llm_service.py | 14 +- .../app/services/model_capabilities.py | 36 +++ .../app/services/model_connection_service.py | 301 +++++++++++------- .../app/services/model_list_service.py | 23 +- .../app/services/model_resolver.py | 50 +-- .../openrouter_integration_service.py | 90 ++---- .../services/openrouter_model_normalizer.py | 121 +++++++ .../app/services/pricing_registration.py | 16 +- .../app/services/provider_capabilities.py | 12 +- .../app/services/provider_registry.py | 98 ++++++ .../app/services/quality_score.py | 2 +- .../flows/new_chat/llm_capability.py | 2 +- .../chat/streaming/flows/shared/llm_bundle.py | 13 +- 31 files changed, 772 insertions(+), 476 deletions(-) create mode 100644 surfsense_backend/app/services/model_capabilities.py create mode 100644 surfsense_backend/app/services/openrouter_model_normalizer.py create mode 100644 surfsense_backend/app/services/provider_registry.py diff --git a/surfsense_backend/alembic/versions/156_add_model_connections.py b/surfsense_backend/alembic/versions/156_add_model_connections.py index 185debca4..64614db99 100644 --- a/surfsense_backend/alembic/versions/156_add_model_connections.py +++ b/surfsense_backend/alembic/versions/156_add_model_connections.py @@ -17,13 +17,6 @@ branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None -connection_protocol = postgresql.ENUM( - "OLLAMA", - "OPENAI_COMPATIBLE", - "ANTHROPIC", - name="connectionprotocol", - create_type=False, -) connection_scope = postgresql.ENUM( "GLOBAL", "SEARCH_SPACE", @@ -73,36 +66,67 @@ def _add_searchspace_column_if_missing(column_name: str) -> None: op.add_column("searchspaces", sa.Column(column_name, sa.Integer(), nullable=True)) +def _drop_column_if_exists(table_name: str, column_name: str) -> None: + if _column_exists(table_name, column_name): + op.drop_column(table_name, column_name) + + +def _drop_index_if_exists(table_name: str, index_name: str) -> None: + if _index_exists(table_name, index_name): + op.drop_index(index_name, table_name=table_name) + + def upgrade() -> None: bind = op.get_bind() - connection_protocol.create(bind, checkfirst=True) - op.execute("ALTER TYPE connectionprotocol ADD VALUE IF NOT EXISTS 'ANTHROPIC'") connection_scope.create(bind, checkfirst=True) model_source.create(bind, checkfirst=True) if _table_exists("connections"): - if _column_exists("connections", "native_provider") and not _column_exists( - "connections", "litellm_provider" + if _column_exists("connections", "litellm_provider") and not _column_exists( + "connections", "provider" + ): + op.alter_column( + "connections", + "litellm_provider", + new_column_name="provider", + existing_type=sa.String(length=100), + existing_nullable=True, + ) + op.alter_column( + "connections", + "provider", + existing_type=sa.String(length=100), + nullable=False, + ) + elif _column_exists("connections", "native_provider") and not _column_exists( + "connections", "provider" ): op.alter_column( "connections", "native_provider", - new_column_name="litellm_provider", + new_column_name="provider", existing_type=sa.String(length=100), existing_nullable=True, ) - elif not _column_exists("connections", "litellm_provider"): + op.alter_column( + "connections", + "provider", + existing_type=sa.String(length=100), + nullable=False, + ) + elif not _column_exists("connections", "provider"): op.add_column( "connections", - sa.Column("litellm_provider", sa.String(length=100), nullable=True), + sa.Column("provider", sa.String(length=100), nullable=False), ) + _drop_index_if_exists("connections", "ix_connections_protocol") + _drop_column_if_exists("connections", "protocol") else: op.create_table( "connections", sa.Column("id", sa.Integer(), nullable=False), sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), - sa.Column("protocol", connection_protocol, nullable=False), - sa.Column("litellm_provider", sa.String(length=100), nullable=True), + sa.Column("provider", sa.String(length=100), nullable=False), sa.Column("base_url", sa.String(length=500), nullable=True), sa.Column("api_key", sa.String(), nullable=True), sa.Column( @@ -131,18 +155,20 @@ def upgrade() -> None: sa.PrimaryKeyConstraint("id"), ) if _index_exists("connections", "ix_connections_native_provider") and not _index_exists( - "connections", "ix_connections_litellm_provider" + "connections", "ix_connections_provider" ): op.execute( "ALTER INDEX ix_connections_native_provider " - "RENAME TO ix_connections_litellm_provider" + "RENAME TO ix_connections_provider" ) - _create_index_if_missing("ix_connections_protocol", "connections", ["protocol"]) - _create_index_if_missing( - "ix_connections_litellm_provider", - "connections", - ["litellm_provider"], - ) + if _index_exists("connections", "ix_connections_litellm_provider") and not _index_exists( + "connections", "ix_connections_provider" + ): + op.execute( + "ALTER INDEX ix_connections_litellm_provider " + "RENAME TO ix_connections_provider" + ) + _create_index_if_missing("ix_connections_provider", "connections", ["provider"]) _create_index_if_missing("ix_connections_scope", "connections", ["scope"]) if not _table_exists("models"): @@ -159,24 +185,11 @@ def upgrade() -> None: server_default="DISCOVERED", nullable=False, ), - sa.Column( - "capabilities", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column( - "capabilities_declared", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), - sa.Column( - "capabilities_verified", - postgresql.JSONB(astext_type=sa.Text()), - server_default=sa.text("'{}'::jsonb"), - nullable=False, - ), + sa.Column("supports_chat", sa.Boolean(), nullable=True), + sa.Column("max_input_tokens", sa.Integer(), nullable=True), + sa.Column("supports_image_input", sa.Boolean(), nullable=True), + sa.Column("supports_tools", sa.Boolean(), nullable=True), + sa.Column("supports_image_generation", sa.Boolean(), nullable=True), sa.Column( "capabilities_override", postgresql.JSONB(astext_type=sa.Text()), @@ -198,6 +211,24 @@ def upgrade() -> None: "connection_id", "model_id", name="uq_models_connection_model_id" ), ) + else: + if not _column_exists("models", "supports_chat"): + op.add_column("models", sa.Column("supports_chat", sa.Boolean(), nullable=True)) + if not _column_exists("models", "max_input_tokens"): + op.add_column("models", sa.Column("max_input_tokens", sa.Integer(), nullable=True)) + if not _column_exists("models", "supports_image_input"): + op.add_column( + "models", sa.Column("supports_image_input", sa.Boolean(), nullable=True) + ) + if not _column_exists("models", "supports_tools"): + op.add_column("models", sa.Column("supports_tools", sa.Boolean(), nullable=True)) + if not _column_exists("models", "supports_image_generation"): + op.add_column( + "models", sa.Column("supports_image_generation", sa.Boolean(), nullable=True) + ) + _drop_column_if_exists("models", "capabilities") + _drop_column_if_exists("models", "capabilities_declared") + _drop_column_if_exists("models", "capabilities_verified") _create_index_if_missing("ix_models_connection_id", "models", ["connection_id"]) _create_index_if_missing("ix_models_model_id", "models", ["model_id"]) _create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"]) @@ -206,6 +237,8 @@ def upgrade() -> None: _add_searchspace_column_if_missing("image_gen_model_id") _add_searchspace_column_if_missing("vision_model_id") + op.execute("DROP TYPE IF EXISTS connectionprotocol") + def downgrade() -> None: op.drop_column("searchspaces", "vision_model_id") @@ -218,11 +251,9 @@ def downgrade() -> None: op.drop_table("models") op.drop_index(op.f("ix_connections_scope"), table_name="connections") - op.drop_index(op.f("ix_connections_litellm_provider"), table_name="connections") - op.drop_index(op.f("ix_connections_protocol"), table_name="connections") + op.drop_index(op.f("ix_connections_provider"), table_name="connections") op.drop_table("connections") bind = op.get_bind() model_source.drop(bind, checkfirst=True) connection_scope.drop(bind, checkfirst=True) - connection_protocol.drop(bind, checkfirst=True) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py index fda327750..505831faa 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py @@ -29,6 +29,7 @@ from app.services.auto_model_pin_service import ( auto_model_candidates, choose_auto_model_candidate, ) +from app.services.model_capabilities import has_capability from app.services.model_resolver import to_litellm from app.utils.signed_image_urls import generate_image_token @@ -146,9 +147,7 @@ def create_generate_image_tool( if config_id < 0: global_model = _get_global_model(config_id) - if not global_model or not ( - global_model.get("capabilities") or {} - ).get("image_gen"): + if not global_model or not has_capability(global_model, "image_gen"): err = f"Image generation model {config_id} not found" return _failed({"error": err}, error=err) global_connection = _get_global_connection( @@ -191,7 +190,7 @@ def create_generate_image_tool( ): err = f"Image generation model {config_id} not found" return _failed({"error": err}, error=err) - if not (db_model.capabilities or {}).get("image_gen"): + if not has_capability(db_model, "image_gen"): err = f"Model {config_id} is not image-generation capable" return _failed({"error": err}, error=err) diff --git a/surfsense_backend/app/agents/chat/runtime/llm_config.py b/surfsense_backend/app/agents/chat/runtime/llm_config.py index b9344e001..efc188df8 100644 --- a/surfsense_backend/app/agents/chat/runtime/llm_config.py +++ b/surfsense_backend/app/agents/chat/runtime/llm_config.py @@ -49,16 +49,19 @@ def _sanitize_messages(messages: list[BaseMessage]) -> list[BaseMessage]: reject the blank text. The OpenAI spec says ``content`` should be ``null`` when an assistant message only carries tool calls. """ + sanitized: list[BaseMessage] = [] for msg in messages: - if isinstance(msg.content, list): - msg.content = _sanitize_content(msg.content) + next_msg = msg.model_copy(deep=True) + if isinstance(next_msg.content, list): + next_msg.content = _sanitize_content(next_msg.content) if ( - isinstance(msg, AIMessage) - and (not msg.content or msg.content == "") - and getattr(msg, "tool_calls", None) + isinstance(next_msg, AIMessage) + and (not next_msg.content or next_msg.content == "") + and getattr(next_msg, "tool_calls", None) ): - msg.content = None # type: ignore[assignment] - return messages + next_msg.content = None # type: ignore[assignment] + sanitized.append(next_msg) + return sanitized class SanitizedChatLiteLLM(ChatLiteLLM): @@ -89,6 +92,22 @@ class SanitizedChatLiteLLM(ChatLiteLLM): ): yield chunk + async def _agenerate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + stream: bool | None = None, + **kwargs: Any, + ) -> ChatResult: + return await super()._agenerate( + _sanitize_messages(messages), + stop=stop, + run_manager=run_manager, + stream=stream, + **kwargs, + ) + def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None: """Attach a ``profile`` dict to ChatLiteLLM with model context metadata.""" @@ -210,7 +229,7 @@ class AgentConfig: # BYOK rows have no curated flag; ask LiteLLM (default-allow on # unknown). The streaming safety net still blocks explicit text-only. supports_image_input=derive_supports_image_input( - litellm_provider=provider_value.lower(), + provider=provider_value.lower(), model_name=config.model_name, base_model=base_model, custom_provider=config.custom_provider, @@ -229,7 +248,7 @@ class AgentConfig: system_instructions = yaml_config.get("system_instructions", "") - provider = yaml_config.get("litellm_provider", "") + provider = yaml_config.get("provider") or yaml_config.get("litellm_provider", "") model_name = yaml_config.get("model_name", "") custom_provider = yaml_config.get("custom_provider") litellm_params = yaml_config.get("litellm_params") or {} @@ -245,7 +264,7 @@ class AgentConfig: supports_image_input = bool(yaml_config.get("supports_image_input")) else: supports_image_input = derive_supports_image_input( - litellm_provider=provider, + provider=provider, model_name=model_name, base_model=base_model, custom_provider=custom_provider, @@ -396,8 +415,8 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: if llm_config.get("custom_provider"): model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}" else: - litellm_provider = llm_config.get("litellm_provider", "openai") - model_string = f"{litellm_provider}/{llm_config['model_name']}" + provider = llm_config.get("provider") or llm_config.get("litellm_provider", "openai") + model_string = f"{provider}/{llm_config['model_name']}" litellm_kwargs = { "model": model_string, diff --git a/surfsense_backend/app/app.py b/surfsense_backend/app/app.py index d3f5dce2a..6dfe6a776 100644 --- a/surfsense_backend/app/app.py +++ b/surfsense_backend/app/app.py @@ -33,7 +33,6 @@ from app.config import ( initialize_llm_router, initialize_openrouter_integration, initialize_pricing_registration, - initialize_vision_llm_router, ) from app.db import User, create_db_and_tables, get_async_session from app.exceptions import GENERIC_5XX_MESSAGE, ISSUES_URL, SurfSenseError @@ -622,7 +621,6 @@ async def lifespan(app: FastAPI): initialize_pricing_registration() initialize_llm_router() initialize_image_gen_router() - initialize_vision_llm_router() # Phase 1.7 — JIT warmup. Bounded so a stuck warmup never delays # worker readiness. ``shield`` so Uvicorn cancelling startup diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py index 0e852b801..7ab98203e 100644 --- a/surfsense_backend/app/celery_app.py +++ b/surfsense_backend/app/celery_app.py @@ -115,14 +115,12 @@ def init_worker(**kwargs): initialize_llm_router, initialize_openrouter_integration, initialize_pricing_registration, - initialize_vision_llm_router, ) initialize_openrouter_integration() initialize_pricing_registration() initialize_llm_router() initialize_image_gen_router() - initialize_vision_llm_router() # Celery configuration, sourced from the central Config singleton diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index fd8b29116..eb2a7a18c 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -103,7 +103,7 @@ def load_global_llm_configs(): else None ) cfg["supports_image_input"] = derive_supports_image_input( - litellm_provider=cfg.get("litellm_provider"), + provider=cfg.get("provider") or cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), @@ -122,7 +122,7 @@ def load_global_llm_configs(): # Stamp Auto (Fastest) ranking metadata. YAML configs are always # Tier A — operator-curated, locked first when premium-eligible. # The OpenRouter refresh tick later re-stamps health for any cfg - # whose litellm_provider == "openrouter" via _enrich_health. + # whose provider == "openrouter" via _enrich_health. try: from app.services.quality_score import static_score_yaml @@ -132,7 +132,7 @@ def load_global_llm_configs(): cfg["quality_score_static"] = static_q cfg["quality_score"] = static_q cfg["quality_score_health"] = None - # YAML cfgs whose litellm_provider is openrouter are also subject + # YAML cfgs whose provider is openrouter are also subject # to health gating against their own /endpoints data — a # hand-picked dead OR model is still dead. _enrich_health # re-stamps health_gated for them on the next refresh tick. @@ -362,8 +362,8 @@ def initialize_openrouter_integration(): else: print("Info: OpenRouter integration enabled but no models fetched") - # Image generation + vision LLM emissions are opt-in (issue L). - # Both reuse the catalogue already cached by ``service.initialize`` + # Image generation emissions reuse the catalogue already cached by + # ``service.initialize`` # so we don't make additional network calls here. if settings.get("image_generation_enabled"): try: @@ -377,18 +377,6 @@ def initialize_openrouter_integration(): except Exception as e: print(f"Warning: Failed to inject OpenRouter image-gen configs: {e}") - if settings.get("vision_enabled"): - try: - vision_configs = service.get_vision_llm_configs() - if vision_configs: - config.GLOBAL_VISION_LLM_CONFIGS.extend(vision_configs) - print( - f"Info: OpenRouter integration added {len(vision_configs)} " - f"vision LLM models" - ) - except Exception as e: - print(f"Warning: Failed to inject OpenRouter vision-LLM configs: {e}") - refresh_global_model_catalog() except Exception as e: print(f"Warning: Failed to initialize OpenRouter integration: {e}") @@ -399,7 +387,6 @@ def materialize_global_configs(): return materialize_global_model_catalog( chat_configs=getattr(config, "GLOBAL_LLM_CONFIGS", []), - vision_configs=getattr(config, "GLOBAL_VISION_LLM_CONFIGS", []), image_configs=getattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", []), ) @@ -493,29 +480,9 @@ def initialize_image_gen_router(): def initialize_vision_llm_router(): - vision_configs = load_global_vision_llm_configs() - # Reuse the router settings already parsed at Config construction. The - # *configs* list is intentionally re-read from YAML (it must exclude the - # OpenRouter-injected dynamic models held in config.GLOBAL_VISION_LLM_CONFIGS). - router_settings = config.VISION_LLM_ROUTER_SETTINGS - - if not vision_configs: - print( - "Info: No global vision LLM configs found, " - "Vision LLM Auto mode will not be available" - ) - return - - try: - from app.services.vision_llm_router_service import VisionLLMRouterService - - VisionLLMRouterService.initialize(vision_configs, router_settings) - print( - f"Info: Vision LLM Router initialized with {len(vision_configs)} models " - f"(strategy: {router_settings.get('routing_strategy', 'usage-based-routing')})" - ) - except Exception as e: - print(f"Warning: Failed to initialize Vision LLM Router: {e}") + # Retired: vision Auto now uses shared capability-filtered model selection + # over GLOBAL/BYOK chat models with supports_image_input=true. + return class Config: @@ -874,7 +841,6 @@ class Config: GLOBAL_CONNECTIONS, GLOBAL_MODELS = _materialize_global_model_catalog( chat_configs=GLOBAL_LLM_CONFIGS, - vision_configs=GLOBAL_VISION_LLM_CONFIGS, image_configs=GLOBAL_IMAGE_GEN_CONFIGS, ) del _materialize_global_model_catalog diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 4c628b05a..9053d7055 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -280,12 +280,6 @@ class VisionProvider(StrEnum): CUSTOM = "CUSTOM" -class ConnectionProtocol(StrEnum): - OLLAMA = "OLLAMA" - OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" - ANTHROPIC = "ANTHROPIC" - - class ConnectionScope(StrEnum): GLOBAL = "GLOBAL" SEARCH_SPACE = "SEARCH_SPACE" @@ -1662,8 +1656,7 @@ class Report(BaseModel, TimestampMixin): class Connection(BaseModel, TimestampMixin): __tablename__ = "connections" - protocol = Column(SQLAlchemyEnum(ConnectionProtocol), nullable=False, index=True) - litellm_provider = Column(String(100), nullable=True, index=True) + provider = Column(String(100), nullable=False, index=True) base_url = Column(String(500), nullable=True) api_key = Column(String, nullable=True) extra = Column(JSONB, nullable=False, default=dict, server_default="{}") @@ -1715,9 +1708,11 @@ class Model(BaseModel, TimestampMixin): default=ModelSource.DISCOVERED, server_default=ModelSource.DISCOVERED.value, ) - capabilities = Column(JSONB, nullable=False, default=dict, server_default="{}") - capabilities_declared = Column(JSONB, nullable=False, default=dict, server_default="{}") - capabilities_verified = Column(JSONB, nullable=False, default=dict, server_default="{}") + supports_chat = Column(Boolean, nullable=True) + max_input_tokens = Column(Integer, nullable=True) + supports_image_input = Column(Boolean, nullable=True) + supports_tools = Column(Boolean, nullable=True) + supports_image_generation = Column(Boolean, nullable=True) capabilities_override = Column(JSONB, nullable=False, default=dict, server_default="{}") embedding_dimension = Column(Integer, nullable=True) enabled = Column(Boolean, nullable=False, default=True, server_default="true") diff --git a/surfsense_backend/app/routes/anonymous_chat_routes.py b/surfsense_backend/app/routes/anonymous_chat_routes.py index aba1a3a12..84420e738 100644 --- a/surfsense_backend/app/routes/anonymous_chat_routes.py +++ b/surfsense_backend/app/routes/anonymous_chat_routes.py @@ -132,7 +132,7 @@ async def list_anonymous_models(): id=cfg.get("id", 0), name=cfg.get("name", ""), description=cfg.get("description"), - provider=cfg.get("litellm_provider", ""), + provider=cfg.get("provider") or cfg.get("litellm_provider", ""), model_name=cfg.get("model_name", ""), billing_tier=cfg.get("billing_tier", "free"), is_premium=cfg.get("billing_tier", "free") == "premium", @@ -161,7 +161,7 @@ async def get_anonymous_model(slug: str): id=cfg.get("id", 0), name=cfg.get("name", ""), description=cfg.get("description"), - provider=cfg.get("litellm_provider", ""), + provider=cfg.get("provider") or cfg.get("litellm_provider", ""), model_name=cfg.get("model_name", ""), billing_tier=cfg.get("billing_tier", "free"), is_premium=cfg.get("billing_tier", "free") == "premium", diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 5be1cedf1..e8f14bd71 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -52,6 +52,7 @@ from app.services.auto_model_pin_service import ( choose_auto_model_candidate, ) from app.services.model_resolver import to_litellm +from app.services.model_capabilities import has_capability from app.users import current_active_user from app.utils.rbac import check_permission from app.utils.signed_image_urls import verify_image_token @@ -166,7 +167,7 @@ async def _execute_image_generation( if config_id < 0: global_model = _get_global_model(config_id) - if not global_model or not (global_model.get("capabilities") or {}).get("image_gen"): + if not global_model or not has_capability(global_model, "image_gen"): raise ValueError(f"Global image generation model {config_id} not found") global_connection = _get_global_connection(global_model["connection_id"]) if not global_connection: @@ -200,7 +201,7 @@ async def _execute_image_generation( raise ValueError(f"Image generation model {config_id} not found") if conn.user_id is not None and conn.user_id != search_space.user_id: raise ValueError(f"Image generation model {config_id} not found") - if not (db_model.capabilities or {}).get("image_gen"): + if not has_capability(db_model, "image_gen"): raise ValueError(f"Model {config_id} is not image-generation capable") model_string, resolved_kwargs = to_litellm( @@ -272,7 +273,7 @@ async def get_global_image_gen_configs( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("litellm_provider"), + "provider": cfg.get("provider") or cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index cae951c3a..ecb86711e 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -8,7 +8,6 @@ from sqlalchemy.orm import selectinload from app.config import config from app.db import ( Connection, - ConnectionProtocol, ConnectionScope, Model, ModelSource, @@ -22,6 +21,7 @@ from app.schemas import ( ConnectionRead, ConnectionUpdate, ModelCreate, + ModelProviderRead, ModelRead, ModelRolesRead, ModelRolesUpdate, @@ -34,6 +34,8 @@ from app.services.model_connection_service import ( persist_verification, test_model, ) +from app.services.model_capabilities import has_capability +from app.services.provider_registry import REGISTRY from app.users import current_active_user from app.utils.rbac import check_permission @@ -41,16 +43,6 @@ router = APIRouter() logger = logging.getLogger(__name__) -def _default_litellm_provider(protocol: ConnectionProtocol | str) -> str: - protocol_value = getattr(protocol, "value", str(protocol)) - defaults = { - ConnectionProtocol.OLLAMA.value: "ollama_chat", - ConnectionProtocol.ANTHROPIC.value: "anthropic", - ConnectionProtocol.OPENAI_COMPATIBLE.value: "openai", - } - return defaults.get(protocol_value, "openai") - - def _model_read(model: Model | dict) -> ModelRead: return ModelRead.model_validate(model) @@ -68,8 +60,7 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None return ConnectionRead( id=conn.id, - protocol=conn.protocol, - litellm_provider=conn.litellm_provider, + provider=conn.provider, base_url=conn.base_url, extra=conn.extra or {}, scope=conn.scope, @@ -85,6 +76,60 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None ) +def _apply_model_facts(model: Model, facts: dict) -> None: + model.supports_chat = facts.get("supports_chat") + model.max_input_tokens = facts.get("max_input_tokens") + model.supports_image_input = facts.get("supports_image_input") + model.supports_tools = facts.get("supports_tools") + model.supports_image_generation = facts.get("supports_image_generation") + + +def _default_model_for(models: list[Model], capability: str) -> int | None: + for model in models: + if model.enabled and has_capability(model, capability): + return model.id + return None + + +async def _default_unset_roles( + session: AsyncSession, + conn: Connection, + models: list[Model], +) -> None: + if conn.scope != ConnectionScope.SEARCH_SPACE or conn.search_space_id is None: + return + search_space = await _get_search_space(session, conn.search_space_id) + if search_space.chat_model_id is None: + search_space.chat_model_id = _default_model_for(models, "chat") + if search_space.vision_model_id is None: + vision_default = None + if search_space.chat_model_id: + chat_model = next((m for m in models if m.id == search_space.chat_model_id), None) + if chat_model and has_capability(chat_model, "vision"): + vision_default = chat_model.id + search_space.vision_model_id = vision_default or _default_model_for(models, "vision") + if search_space.image_gen_model_id is None: + search_space.image_gen_model_id = _default_model_for(models, "image_gen") + + +@router.get("/model-providers", response_model=list[ModelProviderRead]) +async def list_model_providers(user: User = Depends(current_active_user)): + del user + local_only = {"ollama_chat", "lm_studio"} + return [ + ModelProviderRead( + provider=provider, + transport=spec.transport.value, + discovery=spec.discovery, + default_base_url=spec.default_base_url, + base_url_required=spec.base_url_required, + auth_style=spec.auth_style, + local_only=provider in local_only, + ) + for provider, spec in sorted(REGISTRY.items()) + ] + + async def _get_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace: result = await session.execute(select(SearchSpace).where(SearchSpace.id == search_space_id)) search_space = result.scalars().first() @@ -180,8 +225,6 @@ async def create_connection( "You don't have permission to create model connections in this search space", ) payload = data.model_dump(exclude={"search_space_id"}) - if not payload.get("litellm_provider"): - payload["litellm_provider"] = _default_litellm_provider(data.protocol) conn = Connection( **payload, @@ -254,24 +297,21 @@ async def discover_connection_models( model_id=item["model_id"], display_name=item.get("display_name"), source=item["source"], - capabilities=item["capabilities"], - capabilities_declared=item["capabilities"], - capabilities_verified={}, capabilities_override={}, enabled=False, catalog=item.get("metadata") or {}, ) + _apply_model_facts(db_model, item) session.add(db_model) else: db_model.display_name = item.get("display_name") or db_model.display_name - db_model.capabilities_declared = item["capabilities"] - db_model.capabilities = { - **item["capabilities"], - **(db_model.capabilities_override or {}), - } + _apply_model_facts(db_model, item) db_model.catalog = item.get("metadata") or db_model.catalog await session.commit() conn = await _load_connection(session, connection_id) + await _default_unset_roles(session, conn, list(conn.models)) + await session.commit() + conn = await _load_connection(session, connection_id) return [_model_read(model) for model in conn.models] @@ -297,16 +337,17 @@ async def add_manual_model( model_id=model_id, display_name=data.display_name or None, source=ModelSource.MANUAL, - capabilities=capabilities, - capabilities_declared=capabilities, - capabilities_verified={}, capabilities_override={}, enabled=True, catalog={}, ) + _apply_model_facts(model, capabilities) session.add(model) await session.commit() await session.refresh(model) + conn = await _load_connection(session, connection_id) + await _default_unset_roles(session, conn, list(conn.models)) + await session.commit() return _model_read(model) @@ -327,11 +368,6 @@ async def update_model( update = data.model_dump(exclude_unset=True) for key, value in update.items(): setattr(model, key, value) - if "capabilities_override" in update: - model.capabilities = { - **(model.capabilities_declared or {}), - **(model.capabilities_override or {}), - } await session.commit() await session.refresh(model) return _model_read(model) diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 0e4e557be..b5bc2571e 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -1741,12 +1741,11 @@ async def handle_new_chat( if not search_space: raise HTTPException(status_code=404, detail="Search space not found") - # Use agent_llm_id from search space for chat operations - # Positive IDs load from NewLLMConfig database table - # Negative IDs load from YAML global configs - # Falls back to -1 (first global config) if not configured + # Use the converged model-connections role for chat operations. + # Positive IDs load Model + Connection rows; negative IDs load + # virtual GLOBAL models; 0 means Auto. llm_config_id = ( - search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 + search_space.chat_model_id if search_space.chat_model_id is not None else 0 ) # Release the read-transaction so we don't hold ACCESS SHARE locks @@ -2228,7 +2227,7 @@ async def regenerate_response( raise HTTPException(status_code=404, detail="Search space not found") llm_config_id = ( - search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 + search_space.chat_model_id if search_space.chat_model_id is not None else 0 ) # Release the read-transaction so we don't hold ACCESS SHARE locks @@ -2393,7 +2392,7 @@ async def resume_chat( raise HTTPException(status_code=404, detail="Search space not found") llm_config_id = ( - search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 + search_space.chat_model_id if search_space.chat_model_id is not None else 0 ) decisions = [d.model_dump() for d in request.decisions] diff --git a/surfsense_backend/app/routes/new_llm_config_routes.py b/surfsense_backend/app/routes/new_llm_config_routes.py index 531fa8730..adba5b5ae 100644 --- a/surfsense_backend/app/routes/new_llm_config_routes.py +++ b/surfsense_backend/app/routes/new_llm_config_routes.py @@ -57,7 +57,7 @@ def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead: litellm_params.get("base_model") if isinstance(litellm_params, dict) else None ) supports_image_input = derive_supports_image_input( - litellm_provider=provider_value.lower(), + provider=provider_value.lower(), model_name=config.model_name, base_model=base_model, custom_provider=config.custom_provider, @@ -147,7 +147,7 @@ async def get_global_new_llm_configs( else None ) supports_image_input = derive_supports_image_input( - litellm_provider=cfg.get("litellm_provider"), + provider=cfg.get("provider") or cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=cfg_base_model, custom_provider=cfg.get("custom_provider"), @@ -157,7 +157,7 @@ async def get_global_new_llm_configs( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("litellm_provider"), + "provider": cfg.get("provider") or cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 2cda04221..7c5fbf28b 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -419,7 +419,7 @@ async def _get_llm_config_by_id( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("litellm_provider"), + "provider": cfg.get("provider") or cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base"), @@ -490,7 +490,7 @@ async def _get_image_gen_config_by_id( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("litellm_provider"), + "provider": cfg.get("provider") or cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, @@ -550,7 +550,7 @@ async def _get_vision_llm_config_by_id( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("litellm_provider"), + "provider": cfg.get("provider") or cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py index df218daac..b93d25b9c 100644 --- a/surfsense_backend/app/routes/vision_llm_routes.py +++ b/surfsense_backend/app/routes/vision_llm_routes.py @@ -96,7 +96,7 @@ async def get_global_vision_llm_configs( "id": cfg.get("id"), "name": cfg.get("name"), "description": cfg.get("description"), - "provider": cfg.get("litellm_provider"), + "provider": cfg.get("provider") or cfg.get("litellm_provider"), "custom_provider": cfg.get("custom_provider"), "model_name": cfg.get("model_name"), "api_base": cfg.get("api_base") or None, diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 2a06eca5c..dde9fcef1 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -49,6 +49,7 @@ from .model_connections import ( ConnectionRead, ConnectionUpdate, ModelCreate, + ModelProviderRead, ModelRead, ModelRolesRead, ModelRolesUpdate, diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py index 306dd63c8..c081a193d 100644 --- a/surfsense_backend/app/schemas/model_connections.py +++ b/surfsense_backend/app/schemas/model_connections.py @@ -4,7 +4,7 @@ from typing import Any from pydantic import BaseModel, ConfigDict, Field -from app.db import ConnectionProtocol, ConnectionScope, ModelSource +from app.db import ConnectionScope, ModelSource class ModelRead(BaseModel): @@ -13,9 +13,11 @@ class ModelRead(BaseModel): model_id: str display_name: str | None = None source: ModelSource | str - capabilities: dict[str, Any] - capabilities_declared: dict[str, Any] = Field(default_factory=dict) - capabilities_verified: dict[str, Any] = Field(default_factory=dict) + supports_chat: bool | None = None + max_input_tokens: int | None = None + supports_image_input: bool | None = None + supports_tools: bool | None = None + supports_image_generation: bool | None = None capabilities_override: dict[str, Any] = Field(default_factory=dict) embedding_dimension: int | None = None enabled: bool @@ -28,8 +30,7 @@ class ModelRead(BaseModel): class ConnectionRead(BaseModel): id: int - protocol: ConnectionProtocol | str - litellm_provider: str | None = None + provider: str base_url: str | None = None extra: dict[str, Any] = Field(default_factory=dict) scope: ConnectionScope | str @@ -47,8 +48,7 @@ class ConnectionRead(BaseModel): class ConnectionCreate(BaseModel): - protocol: ConnectionProtocol - litellm_provider: str | None = Field(None, max_length=100) + provider: str = Field(..., max_length=100) base_url: str | None = Field(None, max_length=500) api_key: str | None = None extra: dict[str, Any] = Field(default_factory=dict) @@ -58,7 +58,7 @@ class ConnectionCreate(BaseModel): class ConnectionUpdate(BaseModel): - litellm_provider: str | None = Field(None, max_length=100) + provider: str | None = Field(None, max_length=100) base_url: str | None = Field(None, max_length=500) api_key: str | None = None extra: dict[str, Any] | None = None @@ -79,9 +79,24 @@ class ModelCreate(BaseModel): class ModelUpdate(BaseModel): display_name: str | None = Field(None, max_length=255) enabled: bool | None = None + supports_chat: bool | None = None + max_input_tokens: int | None = None + supports_image_input: bool | None = None + supports_tools: bool | None = None + supports_image_generation: bool | None = None capabilities_override: dict[str, Any] | None = None +class ModelProviderRead(BaseModel): + provider: str + transport: str + discovery: str + default_base_url: str | None = None + base_url_required: bool + auth_style: str + local_only: bool = False + + class VerifyConnectionResponse(BaseModel): status: str ok: bool diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index ee8c4b8dc..652029e76 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -27,6 +27,7 @@ from sqlalchemy.orm import selectinload from app.config import config from app.db import Connection, Model, NewChatThread +from app.services.model_capabilities import has_capability from app.services.quality_score import _QUALITY_TOP_K from app.services.token_quota_service import TokenQuotaService @@ -62,18 +63,13 @@ def _is_usable_global_config(cfg: dict) -> bool: return bool( cfg.get("id") is not None and cfg.get("model_name") - and cfg.get("litellm_provider") + and (cfg.get("provider") or cfg.get("litellm_provider")) and cfg.get("api_key") ) def _has_capability(model: dict | Model, capability: str) -> bool: - caps = ( - model.get("capabilities", {}) - if isinstance(model, dict) - else model.capabilities or {} - ) - return bool(caps.get(capability)) + return has_capability(model, capability) def _prune_runtime_cooldowns(now_ts: float | None = None) -> None: @@ -196,7 +192,7 @@ def _cfg_supports_image_input(cfg: dict) -> bool: else None ) return derive_supports_image_input( - litellm_provider=cfg.get("litellm_provider"), + provider=cfg.get("provider") or cfg.get("litellm_provider"), model_name=cfg.get("model_name"), base_model=base_model, custom_provider=cfg.get("custom_provider"), @@ -253,9 +249,13 @@ def _global_candidates( "model_id": model.get("model_id"), "source": "global", "connection": connection, - "capabilities": model.get("capabilities") or {}, + "supports_chat": model.get("supports_chat"), + "supports_image_input": model.get("supports_image_input"), + "supports_tools": model.get("supports_tools"), + "supports_image_generation": model.get("supports_image_generation"), + "capabilities_override": model.get("capabilities_override") or {}, "billing_tier": model.get("billing_tier", "free"), - "litellm_provider": connection.get("litellm_provider"), + "provider": connection.get("provider"), "model_name": model.get("model_id"), "auto_pin_tier": catalog.get("auto_pin_tier") or cfg.get("auto_pin_tier") @@ -310,9 +310,13 @@ async def _db_candidates( "model_id": model.model_id, "source": "db", "connection": conn, - "capabilities": model.capabilities or {}, + "supports_chat": model.supports_chat, + "supports_image_input": model.supports_image_input, + "supports_tools": model.supports_tools, + "supports_image_generation": model.supports_image_generation, + "capabilities_override": model.capabilities_override or {}, "billing_tier": "byok", - "litellm_provider": conn.litellm_provider, + "provider": conn.provider, "model_name": model.model_id, "auto_pin_tier": catalog.get("auto_pin_tier") or "BYOK", "quality_score": catalog.get("quality_score") or 75, @@ -357,7 +361,7 @@ def _is_preferred_premium_auto_config(cfg: dict) -> bool: return ( cfg.get("source") == "global" and _tier_of(cfg) == "premium" - and str(cfg.get("litellm_provider", "")).lower() == "azure" + and str(cfg.get("provider", "")).lower() == "azure" and str(cfg.get("model_name", "")).lower() == "gpt-5.4" ) diff --git a/surfsense_backend/app/services/global_model_catalog.py b/surfsense_backend/app/services/global_model_catalog.py index e40b3a942..ca1249497 100644 --- a/surfsense_backend/app/services/global_model_catalog.py +++ b/surfsense_backend/app/services/global_model_catalog.py @@ -18,8 +18,7 @@ def _connection_key(conn: dict[str, Any]) -> tuple[Any, ...]: # Deliberately includes api_key because two operator-owned credentials for # the same provider/base can have different quota/rate limits upstream. return ( - conn.get("protocol"), - conn.get("litellm_provider"), + conn.get("provider"), conn.get("base_url"), conn.get("api_key"), _freeze(conn.get("extra") or {}), @@ -34,16 +33,6 @@ def _freeze(value: Any) -> Any: return value -def _capabilities_for(role: str, config: dict[str, Any]) -> dict[str, bool]: - return { - "chat": role == "chat", - "vision": role == "vision" or bool(config.get("supports_image_input")), - "image_gen": role == "image_gen", - "embedding": False, - "tools": bool(config.get("supports_tools", False)), - } - - def _catalog_metadata(config: dict[str, Any]) -> dict[str, Any]: return { "billing_tier": config.get("billing_tier", "free"), @@ -72,7 +61,6 @@ def _catalog_metadata(config: dict[str, Any]) -> dict[str, Any]: def materialize_global_model_catalog( *, chat_configs: list[dict[str, Any]], - vision_configs: list[dict[str, Any]], image_configs: list[dict[str, Any]], ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: connections: list[dict[str, Any]] = [] @@ -109,9 +97,13 @@ def materialize_global_model_catalog( "model_id": config["model_name"], "display_name": config.get("name") or config["model_name"], "source": "MANUAL", - "capabilities": _capabilities_for(role, config), - "capabilities_declared": _capabilities_for(role, config), - "capabilities_verified": _capabilities_for(role, config), + "supports_chat": role == "chat", + "max_input_tokens": config.get("max_input_tokens"), + "supports_image_input": ( + role == "chat" and bool(config.get("supports_image_input")) + ), + "supports_tools": bool(config.get("supports_tools", False)), + "supports_image_generation": role == "image_gen", "capabilities_override": {}, "embedding_dimension": None, "enabled": True, @@ -125,10 +117,6 @@ def materialize_global_model_catalog( if cfg.get("is_auto_mode"): continue add_config(cfg, "chat") - for cfg in vision_configs: - if cfg.get("is_auto_mode"): - continue - add_config(cfg, "vision") for cfg in image_configs: if cfg.get("is_auto_mode"): continue diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 86a9c8556..eadb4dbf8 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -15,6 +15,7 @@ from app.services.auto_model_pin_service import ( choose_auto_model_candidate, ) from app.services.llm_router_service import AUTO_MODE_ID, ChatLiteLLMRouter, is_auto_mode +from app.services.model_capabilities import has_capability from app.services.model_resolver import native_connection_from_config, to_litellm from app.services.token_tracking_service import token_tracker @@ -76,7 +77,7 @@ def _legacy_config_connection( api_version: str | None = None, ) -> tuple[str, dict]: cfg = { - "litellm_provider": provider.lower(), + "provider": provider.lower(), "model_name": model_name, "api_key": api_key, "api_base": api_base, @@ -136,12 +137,7 @@ def get_global_connection(connection_id: int) -> dict | None: def _has_capability(model: dict | Model, capability: str) -> bool: - caps = ( - model.get("capabilities", {}) - if isinstance(model, dict) - else model.capabilities or {} - ) - return bool(caps.get(capability)) + return has_capability(model, capability) def _chat_litellm_from_resolved( @@ -420,8 +416,6 @@ async def get_vision_llm( unwrapped — they don't consume premium credit (issue M). """ from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM - from app.services.vision_llm_router_service import is_vision_auto_mode - try: result = await session.execute( select(SearchSpace).where(SearchSpace.id == search_space_id) @@ -468,7 +462,7 @@ async def get_vision_llm( logger.error(f"No vision LLM configured for search space {search_space_id}") return None - if is_vision_auto_mode(config_id): + if config_id == AUTO_MODE_ID: candidates = await auto_model_candidates( session, search_space_id=search_space_id, diff --git a/surfsense_backend/app/services/model_capabilities.py b/surfsense_backend/app/services/model_capabilities.py new file mode 100644 index 000000000..fb7681f35 --- /dev/null +++ b/surfsense_backend/app/services/model_capabilities.py @@ -0,0 +1,36 @@ +"""Override-aware model capability lookup.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +CAPABILITY_FIELDS = { + "chat": "supports_chat", + "vision": "supports_image_input", + "image_gen": "supports_image_generation", + "tools": "supports_tools", +} + + +def _get_value(model: Any, key: str) -> Any: + if isinstance(model, Mapping): + return model.get(key) + return getattr(model, key, None) + + +def has_capability(model: Any, capability: str) -> bool: + field = CAPABILITY_FIELDS.get(capability) + if field is None: + return False + + override = _get_value(model, "capabilities_override") or {} + if isinstance(override, Mapping) and field in override: + return bool(override[field]) + if isinstance(override, Mapping) and capability in override: + return bool(override[capability]) + + return bool(_get_value(model, field)) + + +__all__ = ["CAPABILITY_FIELDS", "has_capability"] diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index 5e5b231f9..428af736e 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -11,8 +11,10 @@ from typing import Any import httpx import litellm -from app.db import Connection, ConnectionProtocol, Model, ModelSource +from app.db import Connection, Model, ModelSource from app.services.model_resolver import ensure_v1, to_litellm +from app.services.openrouter_model_normalizer import normalize_openrouter_models +from app.services.provider_registry import Transport, spec_for logger = logging.getLogger(__name__) @@ -41,6 +43,16 @@ def _anthropic_headers(conn: Connection) -> dict[str, str]: return headers +def _base_url_or_default(conn: Connection) -> str | None: + if conn.base_url: + return conn.base_url.rstrip("/") + if conn.provider == "openai": + return "https://api.openai.com/v1" + if conn.provider == "anthropic": + return "https://api.anthropic.com/v1" + return spec_for(conn.provider).default_base_url + + def _docker_hint(url: str | None, exc_or_status: Any) -> str: raw = str(exc_or_status) if not url: @@ -61,32 +73,30 @@ def _docker_hint(url: str | None, exc_or_status: Any) -> str: async def verify_connection(conn: Connection) -> VerifyResult: - if not conn.base_url: + spec = spec_for(conn.provider) + base_url = _base_url_or_default(conn) + if spec.base_url_required and not base_url: return VerifyResult("UNREACHABLE", False, "Base URL is required.") - if conn.protocol == ConnectionProtocol.OLLAMA: - url = f"{conn.base_url.rstrip('/')}/api/version" - elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: - url = f"{ensure_v1(conn.base_url)}/models" - elif conn.protocol == ConnectionProtocol.ANTHROPIC: - url = f"{conn.base_url.rstrip('/')}/models" + if spec.transport == Transport.OLLAMA and base_url: + url = f"{base_url.rstrip('/')}/api/version" + elif spec.discovery in {"openai_models", "openrouter"} and base_url: + url = f"{ensure_v1(base_url)}/models" + elif spec.discovery == "anthropic_models" and base_url: + url = f"{base_url.rstrip('/')}/models" else: - return VerifyResult("UNREACHABLE", False, "Unsupported connection protocol.") + return VerifyResult("OK", True, "Connection uses provider-native authentication.") try: async with httpx.AsyncClient(timeout=VERIFY_TIMEOUT_SECONDS) as client: - headers = ( - _anthropic_headers(conn) - if conn.protocol == ConnectionProtocol.ANTHROPIC - else _auth_headers(conn) - ) + headers = _anthropic_headers(conn) if spec.auth_style == "x-api-key" else _auth_headers(conn) response = await client.get(url, headers=headers) if response.status_code in (401, 403): return VerifyResult("AUTH_FAILED", False, "Authentication failed.") if response.status_code == 404: - if conn.protocol == ConnectionProtocol.OLLAMA and url.endswith("/v1/models"): + if spec.transport == Transport.OLLAMA and url.endswith("/v1/models"): message = "Ollama native API should not use /v1." - elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: + elif spec.transport == Transport.OPENAI_COMPATIBLE: message = "OpenAI-compatible servers should expose /v1/models." else: message = "Endpoint returned 404." @@ -94,11 +104,11 @@ async def verify_connection(conn: Connection) -> VerifyResult: response.raise_for_status() return VerifyResult("OK", True, "Connection verified.") except httpx.ConnectError as exc: - return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc)) + return VerifyResult("UNREACHABLE", False, _docker_hint(base_url, exc)) except httpx.TimeoutException as exc: return VerifyResult("UNREACHABLE", False, f"Connection timed out: {exc}") except httpx.HTTPError as exc: - return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc)) + return VerifyResult("UNREACHABLE", False, _docker_hint(base_url, exc)) async def persist_verification(conn: Connection) -> VerifyResult: @@ -109,123 +119,193 @@ async def persist_verification(conn: Connection) -> VerifyResult: return result -def _litellm_capabilities(model_string: str, model_id: str) -> dict[str, bool]: - capabilities = { - "chat": True, - "vision": False, - "tools": False, - "image_gen": False, - "embedding": False, - } - with contextlib.suppress(Exception): - capabilities["vision"] = bool(litellm.supports_vision(model=model_string)) - with contextlib.suppress(Exception): - capabilities["tools"] = bool(litellm.supports_function_calling(model=model_string)) - try: - info = litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {} - mode = str(info.get("mode") or "") - capabilities["embedding"] = mode == "embedding" - capabilities["image_gen"] = mode in {"image_generation", "image_generation_model"} - except Exception: - pass - return capabilities - - def _allowlist(conn: Connection) -> set[str]: - """Per-connection model-id allowlist stored in ``extra.model_ids``. - - Empty/absent means "no restriction" (discover everything), mirroring - OpenWebUI's behaviour. A non-empty list restricts discovery to those ids — - essential for providers like OpenRouter that expose hundreds of models. - """ raw = (conn.extra or {}).get("model_ids") or [] return {str(item).strip() for item in raw if str(item).strip()} -async def _discover_openai_shaped_models(conn: Connection, base_url: str | None) -> list[dict[str, Any]]: - if not base_url: +def _litellm_info(model_string: str, model_id: str) -> dict[str, Any]: + with contextlib.suppress(Exception): + info = litellm.get_model_info(model=model_string) + if isinstance(info, dict): + return info + return litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {} + + +def _classify_from_litellm(model_string: str, model_id: str) -> dict[str, Any]: + info = _litellm_info(model_string, model_id) + mode = info.get("mode") + supports_image_input = False + supports_tools = False + with contextlib.suppress(Exception): + supports_image_input = bool(litellm.supports_vision(model=model_string)) + with contextlib.suppress(Exception): + supports_tools = bool(litellm.supports_function_calling(model=model_string)) + return { + "supports_chat": mode in (None, "chat", "completion", "responses"), + "max_input_tokens": info.get("max_input_tokens") or info.get("max_tokens"), + "supports_image_input": supports_image_input, + "supports_tools": supports_tools, + "supports_image_generation": mode in {"image_generation", "image_generation_model"}, + } + + +def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, Any]: + metadata = metadata or {} + spec = spec_for(conn.provider) + model_string, _ = to_litellm(conn, model_id) + facts = _classify_from_litellm(model_string, model_id) + if spec.transport == Transport.OLLAMA: + caps = set(metadata.get("capabilities") or []) + details = metadata.get("details") or {} + facts.update( + { + "supports_chat": "embedding" not in caps, + "supports_image_input": "vision" in caps or facts["supports_image_input"], + "supports_tools": "tools" in caps or facts["supports_tools"], + "supports_image_generation": False, + "max_input_tokens": metadata.get("context_length") + or metadata.get("num_ctx") + or details.get("context_length") + or facts["max_input_tokens"], + } + ) + return facts + + +async def _discover_openai_shaped_models( + conn: Connection, base_url: str | None +) -> list[dict[str, Any]]: + resolved_base_url = base_url or _base_url_or_default(conn) + if not resolved_base_url: return [] - url = f"{ensure_v1(base_url)}/models" + url = f"{ensure_v1(resolved_base_url)}/models" async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: response = await client.get(url, headers=_auth_headers(conn)) response.raise_for_status() - return [ - { - "model_id": item.get("id"), - "display_name": item.get("name") or item.get("id"), - "source": ModelSource.DISCOVERED, - "capabilities": derive_capabilities(conn, item.get("id"), item), - "metadata": item, - } - for item in response.json().get("data", []) - if item.get("id") - ] + + results: list[dict[str, Any]] = [] + for item in response.json().get("data", []): + model_id = item.get("id") + if not model_id: + continue + results.append( + { + "model_id": model_id, + "display_name": item.get("name") or model_id, + "source": ModelSource.DISCOVERED, + **derive_capabilities(conn, model_id, item), + "metadata": item, + } + ) + return results async def _discover_anthropic_models(conn: Connection) -> list[dict[str, Any]]: - if not conn.base_url: + base_url = _base_url_or_default(conn) + if not base_url: return [] - url = f"{conn.base_url.rstrip('/')}/models" + url = f"{base_url.rstrip('/')}/models" async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: response = await client.get(url, headers=_anthropic_headers(conn)) response.raise_for_status() - models = response.json().get("data", []) - return [ - { - "model_id": item.get("id"), - "display_name": item.get("display_name") or item.get("id"), - "source": ModelSource.DISCOVERED, - "capabilities": derive_capabilities(conn, item.get("id"), item), - "metadata": item, - } - for item in models - if item.get("id") - ] + + results: list[dict[str, Any]] = [] + for item in response.json().get("data", []): + model_id = item.get("id") + if not model_id: + continue + results.append( + { + "model_id": model_id, + "display_name": item.get("display_name") or model_id, + "source": ModelSource.DISCOVERED, + **derive_capabilities(conn, model_id, item), + "metadata": item, + } + ) + return results -def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, bool]: - metadata = metadata or {} - if conn.protocol == ConnectionProtocol.OLLAMA: - caps = metadata.get("capabilities") or [] - capabilities = { - "chat": True, - "vision": "vision" in caps, - "tools": False, - "image_gen": False, - "embedding": "embedding" in caps, - } - return capabilities +async def _ollama_tags_then_show(conn: Connection) -> list[dict[str, Any]]: + if not conn.base_url: + return [] - model_string, _ = to_litellm(conn, model_id) - return _litellm_capabilities(model_string, model_id) + base_url = conn.base_url.rstrip("/") + async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: + response = await client.get(f"{base_url}/api/tags", headers=_auth_headers(conn)) + response.raise_for_status() + models = response.json().get("models", []) + results: list[dict[str, Any]] = [] + for item in models: + model_id = item.get("model") or item.get("name") + if not model_id: + continue + metadata = dict(item) + with contextlib.suppress(Exception): + show_response = await client.post( + f"{base_url}/api/show", + json={"model": model_id}, + headers=_auth_headers(conn), + ) + show_response.raise_for_status() + metadata.update(show_response.json()) + results.append( + { + "model_id": model_id, + "display_name": item.get("name") or model_id, + "source": ModelSource.DISCOVERED, + **derive_capabilities(conn, model_id, metadata), + "metadata": metadata, + } + ) + return results + + +async def _openrouter_models(conn: Connection) -> list[dict[str, Any]]: + base_url = _base_url_or_default(conn) or "https://openrouter.ai/api/v1" + async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: + response = await client.get(f"{ensure_v1(base_url)}/models", headers=_auth_headers(conn)) + response.raise_for_status() + return normalize_openrouter_models(response.json().get("data", [])) + + +def _litellm_static_models(conn: Connection) -> list[dict[str, Any]]: + provider = conn.provider + prefix = spec_for(provider).litellm_prefix or provider + results: list[dict[str, Any]] = [] + for model_string, metadata in litellm.model_cost.items(): + if not isinstance(model_string, str) or not model_string.startswith(f"{prefix}/"): + continue + model_id = model_string.split("/", 1)[1] + results.append( + { + "model_id": model_id, + "display_name": metadata.get("display_name") or model_id, + "source": ModelSource.DISCOVERED, + **_classify_from_litellm(model_string, model_id), + "metadata": metadata, + } + ) + return results async def discover_models(conn: Connection) -> list[dict[str, Any]]: allowlist = _allowlist(conn) + spec = spec_for(conn.provider) - if conn.protocol == ConnectionProtocol.OLLAMA: - url = f"{conn.base_url.rstrip('/')}/api/tags" - async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: - response = await client.get(url, headers=_auth_headers(conn)) - response.raise_for_status() - models = response.json().get("models", []) - results = [ - { - "model_id": item.get("model") or item.get("name"), - "display_name": item.get("name") or item.get("model"), - "source": ModelSource.DISCOVERED, - "capabilities": derive_capabilities(conn, item.get("model") or item.get("name"), item), - "metadata": item, - } - for item in models - if item.get("model") or item.get("name") - ] - elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: - results = await _discover_openai_shaped_models(conn, conn.base_url) - elif conn.protocol == ConnectionProtocol.ANTHROPIC: + if spec.discovery == "ollama": + results = await _ollama_tags_then_show(conn) + elif spec.discovery == "openrouter": + results = await _openrouter_models(conn) + elif spec.discovery == "anthropic_models": results = await _discover_anthropic_models(conn) + elif spec.discovery == "openai_models": + results = await _discover_openai_shaped_models(conn, conn.base_url) + elif spec.discovery == "static": + results = _litellm_static_models(conn) else: results = [] @@ -246,10 +326,7 @@ async def test_model(conn: Connection, model: Model) -> VerifyResult: except Exception as exc: return VerifyResult("UNREACHABLE", False, str(exc)) - model.capabilities_verified = { - **(model.capabilities_verified or {}), - "chat": True, - } + model.supports_chat = True return VerifyResult("OK", True, "Model test succeeded.") diff --git a/surfsense_backend/app/services/model_list_service.py b/surfsense_backend/app/services/model_list_service.py index 1ef0b0c90..0761d7e4f 100644 --- a/surfsense_backend/app/services/model_list_service.py +++ b/surfsense_backend/app/services/model_list_service.py @@ -12,6 +12,8 @@ from pathlib import Path import httpx +from app.services.openrouter_model_normalizer import normalize_openrouter_models + logger = logging.getLogger(__name__) OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models" @@ -121,26 +123,13 @@ def _process_models(raw_models: list[dict]) -> list[dict]: """ processed: list[dict] = [] - for model in raw_models: - model_id: str = model.get("id", "") - name: str = model.get("name", "") - context_length = model.get("context_length") - + for normalized in normalize_openrouter_models(raw_models): + model_id: str = normalized["model_id"] + name: str = normalized.get("display_name") or model_id + context_length = normalized.get("max_input_tokens") if "/" not in model_id: continue - if not _is_text_output_model(model): - continue - - if not _supports_tool_calling(model): - continue - - if not _has_sufficient_context(model): - continue - - if not _is_allowed_model(model): - continue - provider_slug, model_name = model_id.split("/", 1) context_window = _format_context_length(context_length) diff --git a/surfsense_backend/app/services/model_resolver.py b/surfsense_backend/app/services/model_resolver.py index ffa77a9a2..ae6fd2877 100644 --- a/surfsense_backend/app/services/model_resolver.py +++ b/surfsense_backend/app/services/model_resolver.py @@ -12,9 +12,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from app.db import Connection -PROTOCOL_OLLAMA = "OLLAMA" -PROTOCOL_OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" -PROTOCOL_ANTHROPIC = "ANTHROPIC" +from app.services.provider_registry import Transport, spec_for def ensure_v1(base_url: str | None) -> str | None: @@ -32,47 +30,25 @@ def _conn_value(conn: Connection | Mapping[str, Any], key: str) -> Any: return getattr(conn, key) -def _protocol_value(protocol: Any) -> str: - return getattr(protocol, "value", str(protocol)) - - -def default_litellm_provider(protocol: Any) -> str: - protocol_value = _protocol_value(protocol) - defaults = { - PROTOCOL_OLLAMA: "ollama_chat", - PROTOCOL_ANTHROPIC: "anthropic", - PROTOCOL_OPENAI_COMPATIBLE: "openai", - } - return defaults.get(protocol_value, "openai") - - -def _execution_api_base(protocol: str, base_url: str | None) -> str | None: - del protocol - if not base_url: - return None - return base_url.rstrip("/") - - def to_litellm( conn: Connection | Mapping[str, Any], model_id: str, ) -> tuple[str, dict[str, Any]]: """Return ``(model_string, litellm_kwargs)`` for any model role.""" - protocol = _protocol_value(_conn_value(conn, "protocol")) + provider = _conn_value(conn, "provider") base_url = _conn_value(conn, "base_url") api_key = _conn_value(conn, "api_key") - litellm_provider = ( - _conn_value(conn, "litellm_provider") or default_litellm_provider(protocol) - ) extra = _conn_value(conn, "extra") or {} + spec = spec_for(provider) kwargs: dict[str, Any] = {} if api_key: kwargs["api_key"] = api_key - model_string = f"{litellm_provider}/{model_id}" if litellm_provider else model_id - api_base = _execution_api_base(protocol, base_url) - if api_base: + prefix = spec.litellm_prefix or str(provider) + model_string = f"{prefix}/{model_id}" if prefix else model_id + if base_url: + api_base = ensure_v1(base_url) if spec.transport == Transport.OPENAI_COMPATIBLE else base_url.rstrip("/") kwargs["api_base"] = api_base if api_version := extra.get("api_version"): @@ -84,11 +60,11 @@ def to_litellm( def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: """Build an in-memory connection mapping from a global config.""" - protocol = str(config.get("protocol") or PROTOCOL_OPENAI_COMPATIBLE) - litellm_provider = str( - config.get("litellm_provider") + provider = str( + config.get("provider") + or config.get("litellm_provider") or config.get("custom_provider") - or default_litellm_provider(protocol) + or "openai" ) extra: dict[str, Any] = { "litellm_params": config.get("litellm_params") or {}, @@ -96,8 +72,7 @@ def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: if config.get("api_version"): extra["api_version"] = config.get("api_version") return { - "protocol": protocol, - "litellm_provider": litellm_provider, + "provider": provider, "base_url": config.get("api_base") or None, "api_key": config.get("api_key") or None, "extra": extra, @@ -105,7 +80,6 @@ def native_connection_from_config(config: Mapping[str, Any]) -> dict[str, Any]: __all__ = [ - "default_litellm_provider", "ensure_v1", "native_connection_from_config", "to_litellm", diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 6996f0fde..fbb70eb5a 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -29,6 +29,13 @@ from app.services.quality_score import ( aggregate_health, static_score_or, ) +from app.services.openrouter_model_normalizer import ( + is_allowed_model as _shared_is_allowed_model, + is_compatible_provider as _shared_is_compatible_provider, + is_openrouter_image_model, + normalize_openrouter_models, + supports_image_input, +) logger = logging.getLogger(__name__) @@ -292,24 +299,16 @@ def _generate_configs( use_default: bool = settings.get("use_default_system_instructions", True) citations_enabled: bool = settings.get("citations_enabled", True) - text_models = [ - m - for m in raw_models - if _is_text_output_model(m) - and _supports_tool_calling(m) - and _has_sufficient_context(m) - and _is_compatible_provider(m) - and _is_allowed_model(m) - and "/" in m.get("id", "") - ] + text_models = normalize_openrouter_models(raw_models) configs: list[dict] = [] taken: set[int] = set() now_ts = int(time.time()) - for model in text_models: - model_id: str = model["id"] - name: str = model.get("name", model_id) + for normalized in text_models: + model = normalized.get("metadata") or {} + model_id: str = normalized["model_id"] + name: str = normalized.get("display_name") or model_id tier = _openrouter_tier(model) static_q = static_score_or(model, now_ts=now_ts) @@ -323,7 +322,7 @@ def _generate_configs( "seo_enabled": seo_enabled, "seo_slug": None, "quota_reserve_tokens": quota_reserve_tokens, - "litellm_provider": "openrouter", + "provider": "openrouter", "model_name": model_id, "api_key": api_key, "api_base": "https://openrouter.ai/api/v1", @@ -345,7 +344,7 @@ def _generate_configs( # ``stream_new_chat`` as a fail-fast safety net before the # OpenRouter request would otherwise 404 with # ``"No endpoints found that support image input"``. - "supports_image_input": _supports_image_input(model), + "supports_image_input": bool(normalized.get("supports_image_input")), _OPENROUTER_DYNAMIC_MARKER: True, # Auto (Fastest) ranking metadata. ``quality_score`` is initialised # to the static score and gets re-blended with health on the next @@ -403,10 +402,7 @@ def _generate_image_gen_configs( image_models = [ m for m in raw_models - if _is_image_output_model(m) - and _is_compatible_provider(m) - and _is_allowed_model(m) - and "/" in m.get("id", "") + if is_openrouter_image_model(m) ] configs: list[dict] = [] @@ -420,7 +416,7 @@ def _generate_image_gen_configs( "id": _stable_config_id(model_id, id_offset, taken), "name": name, "description": f"{name} via OpenRouter (image generation)", - "litellm_provider": "openrouter", + "provider": "openrouter", "model_name": model_id, "api_key": api_key, "api_base": "https://openrouter.ai/api/v1", @@ -468,9 +464,9 @@ def _generate_vision_llm_configs( vision_models = [ m for m in raw_models - if _is_vision_input_model(m) - and _is_compatible_provider(m) - and _is_allowed_model(m) + if supports_image_input(m) + and _shared_is_compatible_provider(m) + and _shared_is_allowed_model(m) and "/" in m.get("id", "") ] @@ -499,7 +495,7 @@ def _generate_vision_llm_configs( "id": _stable_config_id(model_id, id_offset, taken), "name": name, "description": f"{name} via OpenRouter (vision)", - "litellm_provider": "openrouter", + "provider": "openrouter", "model_name": model_id, "api_key": api_key, "api_base": "https://openrouter.ai/api/v1", @@ -544,11 +540,9 @@ class OpenRouterIntegrationService: # Cached raw catalogue from the most recent fetch. Image / vision # emitters reuse this to avoid a second network call per surface. self._raw_models: list[dict] = [] - # Image / vision config caches (only populated when the matching - # opt-in flag is true on initialize). Refreshed in lockstep with - # the chat catalogue. + # Image config cache (only populated when the matching opt-in flag is + # true on initialize). Refreshed in lockstep with the chat catalogue. self._image_configs: list[dict] = [] - self._vision_configs: list[dict] = [] @classmethod def get_instance(cls) -> "OpenRouterIntegrationService": @@ -583,7 +577,7 @@ class OpenRouterIntegrationService: self._configs_by_id = {c["id"]: c for c in self._configs} self._raw_pricing = _extract_raw_pricing(raw_models) - # Populate image / vision caches when their opt-in flag is set. + # Populate image cache when its opt-in flag is set. # Empty otherwise so the accessors return [] without re-running # filters every refresh. if settings.get("image_generation_enabled"): @@ -595,15 +589,6 @@ class OpenRouterIntegrationService: else: self._image_configs = [] - if settings.get("vision_enabled"): - self._vision_configs = _generate_vision_llm_configs(raw_models, settings) - logger.info( - "OpenRouter integration: vision LLM emission ON (%d models)", - len(self._vision_configs), - ) - else: - self._vision_configs = [] - self._initialized = True tier_counts = self._tier_counts(self._configs) @@ -657,9 +642,9 @@ class OpenRouterIntegrationService: self._configs = new_configs self._configs_by_id = new_by_id - # Image / vision lists are atomic-swapped the same way: filter out + # Image list is atomic-swapped the same way: filter out # the previous dynamic entries from the live config list and append - # the freshly generated ones. No-ops when the opt-in flag is off. + # the freshly generated ones. No-op when the opt-in flag is off. if self._settings.get("image_generation_enabled"): new_image = _generate_image_gen_configs(raw_models, self._settings) static_image = [ @@ -670,16 +655,6 @@ class OpenRouterIntegrationService: app_config.GLOBAL_IMAGE_GEN_CONFIGS = static_image + new_image self._image_configs = new_image - if self._settings.get("vision_enabled"): - new_vision = _generate_vision_llm_configs(raw_models, self._settings) - static_vision = [ - c - for c in app_config.GLOBAL_VISION_LLM_CONFIGS - if not c.get(_OPENROUTER_DYNAMIC_MARKER) - ] - app_config.GLOBAL_VISION_LLM_CONFIGS = static_vision + new_vision - self._vision_configs = new_vision - # Catalogue churn invalidates per-config "recently healthy" credit # earned by the previous turn's preflight. Drop the whole table so # the next turn re-probes against the freshly loaded configs. @@ -701,7 +676,7 @@ class OpenRouterIntegrationService: ) # Re-blend health scores against the freshly fetched catalogue. Also - # re-stamps health for any YAML-curated cfg with litellm_provider=openrouter + # re-stamps health for any YAML-curated cfg with provider=openrouter # so a hand-picked dead OR model is gated like a dynamic one. await self._enrich_health_safely(static_configs + new_configs, log_summary=True) @@ -778,7 +753,7 @@ class OpenRouterIntegrationService: the entire previous cycle's cache for this run. """ or_cfgs = [ - c for c in configs if str(c.get("litellm_provider", "")).lower() == "openrouter" + c for c in configs if str(c.get("provider", "")).lower() == "openrouter" ] if not or_cfgs: return @@ -959,17 +934,6 @@ class OpenRouterIntegrationService: """ return list(self._image_configs) - def get_vision_llm_configs(self) -> list[dict]: - """Return the dynamic OpenRouter vision-LLM configs (empty list - when the ``vision_enabled`` flag is off). - - Each entry exposes ``input_cost_per_token`` / ``output_cost_per_token`` - so ``pricing_registration`` can teach LiteLLM the cost of these - models the same way it does for chat — which keeps the billable - wrapper able to debit accurate micro-USD on a vision call. - """ - return list(self._vision_configs) - def get_raw_pricing(self) -> dict[str, dict[str, str]]: """Return the cached raw OpenRouter pricing map. diff --git a/surfsense_backend/app/services/openrouter_model_normalizer.py b/surfsense_backend/app/services/openrouter_model_normalizer.py new file mode 100644 index 000000000..0b646f933 --- /dev/null +++ b/surfsense_backend/app/services/openrouter_model_normalizer.py @@ -0,0 +1,121 @@ +"""Shared OpenRouter model normalization. + +OpenRouter metadata is richer than generic OpenAI-compatible ``/models`` +responses. Keep all OpenRouter filtering and capability extraction here so +GLOBAL catalogue generation and BYOK discovery agree. +""" + +from __future__ import annotations + +from typing import Any + +from app.db import ModelSource + +MIN_CONTEXT_LENGTH = 100_000 + +EXCLUDED_PROVIDER_SLUGS = {"amazon"} +EXCLUDED_MODEL_IDS: set[str] = { + "openai/gpt-4-1106-preview", + "openai/gpt-4-turbo-preview", + "openai/gpt-4o:extended", + "arcee-ai/virtuoso-large", + "openai/o3-deep-research", + "openai/o4-mini-deep-research", + "openrouter/free", +} +EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",) + + +def is_text_output_model(model: dict[str, Any]) -> bool: + output_mods = model.get("architecture", {}).get("output_modalities", []) + return output_mods == ["text"] + + +def is_image_output_model(model: dict[str, Any]) -> bool: + output_mods = model.get("architecture", {}).get("output_modalities", []) or [] + return "image" in output_mods + + +def supports_image_input(model: dict[str, Any]) -> bool: + input_mods = model.get("architecture", {}).get("input_modalities", []) or [] + return "image" in input_mods + + +def supports_tool_calling(model: dict[str, Any]) -> bool: + supported = model.get("supported_parameters") or [] + return "tools" in supported + + +def has_sufficient_context(model: dict[str, Any]) -> bool: + return int(model.get("context_length") or 0) >= MIN_CONTEXT_LENGTH + + +def is_compatible_provider(model: dict[str, Any]) -> bool: + model_id = str(model.get("id") or "") + slug = model_id.split("/", 1)[0] if "/" in model_id else "" + return slug not in EXCLUDED_PROVIDER_SLUGS + + +def is_allowed_model(model: dict[str, Any]) -> bool: + model_id = str(model.get("id") or "") + if model_id in EXCLUDED_MODEL_IDS: + return False + base_id = model_id.split(":")[0] + return not base_id.endswith(EXCLUDED_MODEL_SUFFIXES) + + +def is_openrouter_chat_model(model: dict[str, Any]) -> bool: + return ( + "/" in str(model.get("id") or "") + and is_text_output_model(model) + and supports_tool_calling(model) + and has_sufficient_context(model) + and is_compatible_provider(model) + and is_allowed_model(model) + ) + + +def is_openrouter_image_model(model: dict[str, Any]) -> bool: + return ( + "/" in str(model.get("id") or "") + and is_image_output_model(model) + and is_compatible_provider(model) + and is_allowed_model(model) + ) + + +def normalize_openrouter_models(raw_models: list[dict[str, Any]]) -> list[dict[str, Any]]: + normalized: list[dict[str, Any]] = [] + for model in raw_models: + if not is_openrouter_chat_model(model): + continue + model_id = str(model.get("id") or "") + normalized.append( + { + "model_id": model_id, + "display_name": model.get("name") or model_id, + "source": ModelSource.DISCOVERED, + "supports_chat": True, + "max_input_tokens": model.get("context_length"), + "supports_image_input": supports_image_input(model), + "supports_tools": supports_tool_calling(model), + "supports_image_generation": False, + "metadata": model, + } + ) + return normalized + + +__all__ = [ + "MIN_CONTEXT_LENGTH", + "has_sufficient_context", + "is_allowed_model", + "is_compatible_provider", + "is_image_output_model", + "is_openrouter_chat_model", + "is_openrouter_image_model", + "is_text_output_model", + "normalize_openrouter_models", + "supports_image_input", + "supports_tool_calling", +] diff --git a/surfsense_backend/app/services/pricing_registration.py b/surfsense_backend/app/services/pricing_registration.py index 6b99fe723..9e4e3b552 100644 --- a/surfsense_backend/app/services/pricing_registration.py +++ b/surfsense_backend/app/services/pricing_registration.py @@ -143,7 +143,7 @@ def _register_chat_shape_configs( sample_keys: list[str] = [] for cfg in configs: - provider = str(cfg.get("litellm_provider") or "").lower() + provider = str(cfg.get("provider") or cfg.get("litellm_provider") or "").lower() model_name = str(cfg.get("model_name") or "").strip() litellm_params = cfg.get("litellm_params") or {} base_model = str(litellm_params.get("base_model") or model_name).strip() @@ -216,9 +216,8 @@ def _register_chat_shape_configs( def register_pricing_from_global_configs() -> None: """Register pricing for every known LLM deployment with LiteLLM. - Walks ``config.GLOBAL_LLM_CONFIGS`` *and* ``config.GLOBAL_VISION_LLM_CONFIGS`` - so vision calls (during indexing) can resolve cost the same way chat - calls do — namely: + Walks ``config.GLOBAL_LLM_CONFIGS`` so chat and vision calls can resolve + cost from the same chat-shaped deployment configs: 1. ``OPENROUTER``: pulls the cached raw pricing from ``OpenRouterIntegrationService`` (populated during its own @@ -245,10 +244,7 @@ def register_pricing_from_global_configs() -> None: from app.config import config as app_config chat_configs: list[dict] = list(getattr(app_config, "GLOBAL_LLM_CONFIGS", []) or []) - vision_configs: list[dict] = list( - getattr(app_config, "GLOBAL_VISION_LLM_CONFIGS", []) or [] - ) - if not chat_configs and not vision_configs: + if not chat_configs: logger.info("[PricingRegistration] no global configs to register") return @@ -267,7 +263,3 @@ def register_pricing_from_global_configs() -> None: if chat_configs: _register_chat_shape_configs(chat_configs, or_pricing=or_pricing, label="chat") - if vision_configs: - _register_chat_shape_configs( - vision_configs, or_pricing=or_pricing, label="vision" - ) diff --git a/surfsense_backend/app/services/provider_capabilities.py b/surfsense_backend/app/services/provider_capabilities.py index 9521ef7a4..fae283ab6 100644 --- a/surfsense_backend/app/services/provider_capabilities.py +++ b/surfsense_backend/app/services/provider_capabilities.py @@ -51,7 +51,7 @@ logger = logging.getLogger(__name__) def _candidate_model_strings( *, - litellm_provider: str | None, + provider: str | None, model_name: str | None, base_model: str | None, custom_provider: str | None, @@ -78,7 +78,7 @@ def _candidate_model_strings( seen.add(key) candidates.append(key) - provider_prefix = custom_provider or litellm_provider + provider_prefix = custom_provider or provider primary_model = base_model or model_name bare_model = model_name @@ -113,7 +113,7 @@ def _candidate_model_strings( def derive_supports_image_input( *, - litellm_provider: str | None = None, + provider: str | None = None, model_name: str | None = None, base_model: str | None = None, custom_provider: str | None = None, @@ -147,7 +147,7 @@ def derive_supports_image_input( return False for model_string, custom_llm_provider in _candidate_model_strings( - litellm_provider=litellm_provider, + provider=provider, model_name=model_name, base_model=base_model, custom_provider=custom_provider, @@ -172,7 +172,7 @@ def derive_supports_image_input( def is_known_text_only_chat_model( *, - litellm_provider: str | None = None, + provider: str | None = None, model_name: str | None = None, base_model: str | None = None, custom_provider: str | None = None, @@ -193,7 +193,7 @@ def is_known_text_only_chat_model( leads to the regression we're fixing here. """ for model_string, custom_llm_provider in _candidate_model_strings( - litellm_provider=litellm_provider, + provider=provider, model_name=model_name, base_model=base_model, custom_provider=custom_provider, diff --git a/surfsense_backend/app/services/provider_registry.py b/surfsense_backend/app/services/provider_registry.py new file mode 100644 index 000000000..871769f11 --- /dev/null +++ b/surfsense_backend/app/services/provider_registry.py @@ -0,0 +1,98 @@ +"""Provider registry for model connections. + +The provider string is the single public identity axis. This registry only +describes providers whose behavior differs from LiteLLM's native default. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import StrEnum +from typing import Literal + + +class Transport(StrEnum): + NATIVE = "NATIVE" + OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" + OLLAMA = "OLLAMA" + + +DiscoveryKind = Literal[ + "ollama", + "openai_models", + "anthropic_models", + "openrouter", + "static", + "none", +] + +AuthStyle = Literal["bearer", "x-api-key", "none", "native"] + + +@dataclass(frozen=True) +class ProviderSpec: + transport: Transport + litellm_prefix: str | None + discovery: DiscoveryKind + default_base_url: str | None + base_url_required: bool + auth_style: AuthStyle + + +REGISTRY: dict[str, ProviderSpec] = { + "openai": ProviderSpec( + Transport.NATIVE, "openai", "openai_models", None, False, "bearer" + ), + "anthropic": ProviderSpec( + Transport.NATIVE, "anthropic", "anthropic_models", None, False, "x-api-key" + ), + "azure": ProviderSpec(Transport.NATIVE, "azure", "static", None, True, "native"), + "vertex_ai": ProviderSpec( + Transport.NATIVE, "vertex_ai", "static", None, False, "native" + ), + "bedrock": ProviderSpec( + Transport.NATIVE, "bedrock", "static", None, False, "native" + ), + "openrouter": ProviderSpec( + Transport.OPENAI_COMPATIBLE, + "openrouter", + "openrouter", + "https://openrouter.ai/api/v1", + False, + "bearer", + ), + "openai_compatible": ProviderSpec( + Transport.OPENAI_COMPATIBLE, + "openai", + "openai_models", + None, + True, + "bearer", + ), + "lm_studio": ProviderSpec( + Transport.OPENAI_COMPATIBLE, + "openai", + "openai_models", + "http://localhost:1234/v1", + True, + "bearer", + ), + "ollama_chat": ProviderSpec( + Transport.OLLAMA, + "ollama_chat", + "ollama", + "http://localhost:11434", + True, + "none", + ), +} + + +def spec_for(provider: str | None) -> ProviderSpec: + provider_key = (provider or "").strip() + return REGISTRY.get(provider_key) or ProviderSpec( + Transport.NATIVE, provider_key or "openai", "static", None, False, "native" + ) + + +__all__ = ["REGISTRY", "ProviderSpec", "Transport", "spec_for"] diff --git a/surfsense_backend/app/services/quality_score.py b/surfsense_backend/app/services/quality_score.py index 95484439b..9cc9c21ac 100644 --- a/surfsense_backend/app/services/quality_score.py +++ b/surfsense_backend/app/services/quality_score.py @@ -273,7 +273,7 @@ def static_score_yaml(cfg: dict) -> int: listed this model. Pricing / context fall through to lazy ``litellm`` lookups; failures are silent (we just lose those sub-points). """ - provider = str(cfg.get("litellm_provider", "")).lower() + provider = str(cfg.get("provider") or cfg.get("litellm_provider") or "").lower() base = PROVIDER_PRESTIGE_YAML.get(provider, 15) model_name = cfg.get("model_name") or "" diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py index f6fcf75d7..69b9f4ab8 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/llm_capability.py @@ -40,7 +40,7 @@ def check_image_input_capability( else None ) if not is_known_text_only_chat_model( - litellm_provider=agent_config.provider, + provider=agent_config.provider, model_name=agent_config.model_name, base_model=agent_base_model, custom_provider=agent_config.custom_provider, diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py index f6870f5fa..cfd50950e 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/llm_bundle.py @@ -22,6 +22,7 @@ from app.agents.chat.runtime.llm_config import ( ) from app.config import config from app.db import Model, SearchSpace +from app.services.model_capabilities import has_capability from app.services.model_resolver import to_litellm @@ -96,7 +97,7 @@ async def load_llm_bundle( model_id=config_id, search_space=search_space, ) - if not model or not (model.capabilities or {}).get("chat"): + if not model or not has_capability(model, "chat"): return ( None, None, @@ -106,12 +107,12 @@ async def load_llm_bundle( agent_config = _agent_config_from_resolved( config_id=config_id, config_name=model.display_name or model.model_id, - provider=model.connection.litellm_provider or "", + provider=model.connection.provider or "", model_name=model.model_id, api_key=model.connection.api_key, api_base=model.connection.base_url, litellm_params=(model.connection.extra or {}).get("litellm_params"), - supports_image_input=bool((model.capabilities or {}).get("vision")), + supports_image_input=has_capability(model, "vision"), billing_tier="free", ) return ( @@ -121,7 +122,7 @@ async def load_llm_bundle( ) global_model = next((m for m in config.GLOBAL_MODELS if m.get("id") == config_id), None) - if not global_model or not (global_model.get("capabilities") or {}).get("chat"): + if not global_model or not has_capability(global_model, "chat"): return None, None, f"Failed to load global chat model with id {config_id}" global_connection = next( ( @@ -137,12 +138,12 @@ async def load_llm_bundle( agent_config = _agent_config_from_resolved( config_id=config_id, config_name=global_model.get("display_name") or global_model.get("model_id"), - provider=global_connection.get("litellm_provider") or "", + provider=global_connection.get("provider") or "", model_name=global_model["model_id"], api_key=global_connection.get("api_key"), api_base=global_connection.get("base_url"), litellm_params=(global_connection.get("extra") or {}).get("litellm_params"), - supports_image_input=bool((global_model.get("capabilities") or {}).get("vision")), + supports_image_input=has_capability(global_model, "vision"), billing_tier=str(global_model.get("billing_tier", "free")).lower(), ) return ( From 3dd54230e72a26027dafa96038aa762aa30cbb06 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 02:17:37 +0530 Subject: [PATCH 23/59] fix(chat): normalize provider-safe message history --- .../app/tasks/chat/llm_history_normalizer.py | 89 +++++++++++++++++++ .../tasks/chat/message_parts_normalizer.py | 86 ++++++++++++++++++ .../tasks/chat/streaming/agent/event_loop.py | 6 ++ .../flows/shared/assistant_finalize.py | 5 ++ .../chat/streaming/shared/stream_result.py | 4 + surfsense_backend/app/utils/content_utils.py | 28 ++++-- .../chat/runtime/test_llm_config_sanitizer.py | 40 +++++++++ .../tasks/chat/test_llm_history_normalizer.py | 62 +++++++++++++ .../chat/test_message_parts_normalizer.py | 68 ++++++++++++++ 9 files changed, 382 insertions(+), 6 deletions(-) create mode 100644 surfsense_backend/app/tasks/chat/llm_history_normalizer.py create mode 100644 surfsense_backend/app/tasks/chat/message_parts_normalizer.py create mode 100644 surfsense_backend/tests/unit/agents/chat/runtime/test_llm_config_sanitizer.py create mode 100644 surfsense_backend/tests/unit/tasks/chat/test_llm_history_normalizer.py create mode 100644 surfsense_backend/tests/unit/tasks/chat/test_message_parts_normalizer.py diff --git a/surfsense_backend/app/tasks/chat/llm_history_normalizer.py b/surfsense_backend/app/tasks/chat/llm_history_normalizer.py new file mode 100644 index 000000000..e3470d06b --- /dev/null +++ b/surfsense_backend/app/tasks/chat/llm_history_normalizer.py @@ -0,0 +1,89 @@ +"""Convert persisted chat content into provider-safe LangChain history. + +Assistant UI parts are a UI/storage shape, not an LLM prompt shape. This module +extracts only model-safe content before prior turns are replayed to a provider. +""" + +from __future__ import annotations + +from typing import Any + +_USER_CONTENT_TYPES = {"text", "image", "image_url"} + + +def _text_from_block(block: dict[str, Any]) -> str: + value = block.get("text") or block.get("content") or "" + return value if isinstance(value, str) else "" + + +def assistant_content_to_llm_text(content: Any) -> str: + """Return visible assistant text, dropping reasoning/UI/provider blocks.""" + if isinstance(content, str): + return content + if isinstance(content, dict): + return _text_from_block(content) + if not isinstance(content, list): + return "" + + text_chunks: list[str] = [] + for block in content: + if isinstance(block, str): + if block: + text_chunks.append(block) + continue + if not isinstance(block, dict): + continue + if block.get("type") == "text": + text = _text_from_block(block) + if text: + text_chunks.append(text) + return "\n".join(text_chunks) + + +def user_content_to_llm_content( + content: Any, + *, + allow_images: bool = True, +) -> str | list[dict[str, Any]]: + """Return provider-safe user text/image content for LangChain.""" + if isinstance(content, str): + return content + if isinstance(content, dict): + return _text_from_block(content) + if not isinstance(content, list): + return "" + + parts: list[dict[str, Any]] = [] + text_chunks: list[str] = [] + for block in content: + if isinstance(block, str): + if block: + text_chunks.append(block) + continue + if not isinstance(block, dict): + continue + block_type = block.get("type") + if block_type not in _USER_CONTENT_TYPES: + continue + if block_type == "text": + text = _text_from_block(block) + if text: + parts.append({"type": "text", "text": text}) + text_chunks.append(text) + elif allow_images and block_type == "image": + image = block.get("image") + if isinstance(image, str) and image.startswith("data:"): + parts.append({"type": "image_url", "image_url": {"url": image}}) + elif allow_images and block_type == "image_url": + image_url = block.get("image_url") + if isinstance(image_url, dict): + url = image_url.get("url") + if isinstance(url, str) and url.startswith("data:"): + parts.append({"type": "image_url", "image_url": {"url": url}}) + elif isinstance(image_url, str) and image_url.startswith("data:"): + parts.append({"type": "image_url", "image_url": {"url": image_url}}) + + if allow_images and any(part.get("type") == "image_url" for part in parts): + return parts + return "\n".join(text_chunks) + diff --git a/surfsense_backend/app/tasks/chat/message_parts_normalizer.py b/surfsense_backend/app/tasks/chat/message_parts_normalizer.py new file mode 100644 index 000000000..953282e9f --- /dev/null +++ b/surfsense_backend/app/tasks/chat/message_parts_normalizer.py @@ -0,0 +1,86 @@ +"""Normalize final LangChain assistant messages into assistant-ui parts. + +Live streaming remains the primary source for rich, incremental UI state. +This module is only used after the graph has finished so refresh persistence +does not depend on provider-specific streaming chunk shapes. +""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +from langchain_core.messages import AIMessage + + +def _text_from_content(content: Any) -> str: + if isinstance(content, str): + return content + if not isinstance(content, list): + return "" + + text_parts: list[str] = [] + for block in content: + if not isinstance(block, dict): + continue + if block.get("type") != "text": + continue + value = block.get("text") or block.get("content") or "" + if isinstance(value, str) and value: + text_parts.append(value) + return "".join(text_parts) + + +def normalize_ai_message_to_parts(message: AIMessage | Any | None) -> list[dict[str, Any]]: + """Return user-visible assistant-ui parts for a final AI message. + + We intentionally do not backfill provider ``thinking`` / + ``reasoning_content`` blocks here. If reasoning streamed live, the + ``AssistantContentBuilder`` already captured it. If it only exists in the + final model payload, persisting it retroactively could expose content the + UI never showed during the turn. + """ + if message is None: + return [] + + text = _text_from_content(getattr(message, "content", None)).strip() + if not text: + return [] + return [{"type": "text", "text": text}] + + +def last_ai_message(messages: Iterable[Any] | None) -> AIMessage | Any | None: + if messages is None: + return None + for message in reversed(list(messages)): + if isinstance(message, AIMessage): + return message + if getattr(message, "type", None) == "ai": + return message + return None + + +def final_assistant_parts_from_messages(messages: Iterable[Any] | None) -> list[dict[str, Any]]: + return normalize_ai_message_to_parts(last_ai_message(messages)) + + +def has_non_empty_text_part(parts: Iterable[dict[str, Any]]) -> bool: + return any( + part.get("type") == "text" + and isinstance(part.get("text"), str) + and bool(part.get("text", "").strip()) + for part in parts + ) + + +def merge_streamed_and_final_parts( + streamed_parts: list[dict[str, Any]], + final_parts: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Use final-state text only when streaming captured no answer text.""" + if has_non_empty_text_part(streamed_parts): + return streamed_parts + if not has_non_empty_text_part(final_parts): + return streamed_parts + return [*streamed_parts, *final_parts] + diff --git a/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py b/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py index d96144bcd..fb7548818 100644 --- a/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py +++ b/surfsense_backend/app/tasks/chat/streaming/agent/event_loop.py @@ -25,6 +25,9 @@ from app.tasks.chat.streaming.graph_stream.event_stream import stream_output from app.tasks.chat.streaming.helpers.interrupt_inspector import ( all_interrupt_values, ) +from app.tasks.chat.message_parts_normalizer import ( + final_assistant_parts_from_messages, +) from app.tasks.chat.streaming.shared.stream_result import StreamResult from app.tasks.chat.streaming.shared.utils import safe_float from app.utils.perf import get_perf_logger @@ -75,6 +78,9 @@ async def stream_agent_events( state = await agent.aget_state(config) state_values = getattr(state, "values", {}) or {} + result.final_message_parts = final_assistant_parts_from_messages( + state_values.get("messages") + ) # Safety net: if astream_events was cancelled before # KnowledgeBasePersistenceMiddleware.aafter_agent ran, any staged work diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py b/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py index be1f102f3..3f767c60b 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/shared/assistant_finalize.py @@ -53,6 +53,7 @@ async def finalize_assistant_message( ): return + from app.tasks.chat.message_parts_normalizer import merge_streamed_and_final_parts from app.tasks.chat.persistence import finalize_assistant_turn builder_stats: dict[str, int] | None = None @@ -74,6 +75,10 @@ async def finalize_assistant_message( "text": stream_result.accumulated_text or "", } ] + content_payload = merge_streamed_and_final_parts( + content_payload, + stream_result.final_message_parts, + ) if builder_stats is not None: _perf_log.info( diff --git a/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py b/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py index a940e8a9f..5e164070a 100644 --- a/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py +++ b/surfsense_backend/app/tasks/chat/streaming/shared/stream_result.py @@ -35,3 +35,7 @@ class StreamResult: # (``StreamResult`` is logged in some error branches) from dumping a # potentially-large parts list. content_builder: Any | None = field(default=None, repr=False) + # User-visible assistant message parts derived from the final LangGraph + # state. Used after streaming completes as a provider-agnostic persistence + # backfill when no text chunks reached the live stream. + final_message_parts: list[dict[str, Any]] = field(default_factory=list) diff --git a/surfsense_backend/app/utils/content_utils.py b/surfsense_backend/app/utils/content_utils.py index 05a4610c7..aae936888 100644 --- a/surfsense_backend/app/utils/content_utils.py +++ b/surfsense_backend/app/utils/content_utils.py @@ -18,6 +18,11 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload +from app.tasks.chat.llm_history_normalizer import ( + assistant_content_to_llm_text, + user_content_to_llm_content, +) + if TYPE_CHECKING: from app.db import ChatVisibility @@ -95,17 +100,28 @@ async def bootstrap_history_from_db( langchain_messages: list[HumanMessage | AIMessage] = [] for msg in db_messages: - text_content = extract_text_content(msg.content) - if not text_content: - continue if msg.role == "user": + user_content = user_content_to_llm_content( + msg.content, + allow_images=False, + ) + if not user_content: + continue if is_shared: author_name = ( msg.author.display_name if msg.author else None ) or "A team member" - text_content = f"**[{author_name}]:** {text_content}" - langchain_messages.append(HumanMessage(content=text_content)) + if isinstance(user_content, str): + user_content = f"**[{author_name}]:** {user_content}" + elif user_content and user_content[0].get("type") == "text": + user_content[0] = { + **user_content[0], + "text": f"**[{author_name}]:** {user_content[0].get('text', '')}", + } + langchain_messages.append(HumanMessage(content=user_content)) elif msg.role == "assistant": - langchain_messages.append(AIMessage(content=text_content)) + assistant_text = assistant_content_to_llm_text(msg.content) + if assistant_text: + langchain_messages.append(AIMessage(content=assistant_text)) return langchain_messages diff --git a/surfsense_backend/tests/unit/agents/chat/runtime/test_llm_config_sanitizer.py b/surfsense_backend/tests/unit/agents/chat/runtime/test_llm_config_sanitizer.py new file mode 100644 index 000000000..00d38b4b1 --- /dev/null +++ b/surfsense_backend/tests/unit/agents/chat/runtime/test_llm_config_sanitizer.py @@ -0,0 +1,40 @@ +"""Regression tests for model-boundary message sanitization.""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import AIMessage + +from app.agents.chat.runtime.llm_config import _sanitize_messages + +pytestmark = pytest.mark.unit + + +def test_sanitize_messages_strips_provider_specific_thinking_blocks() -> None: + original = AIMessage( + content=[ + {"type": "thinking", "thinking": "private reasoning"}, + {"type": "text", "text": "visible answer"}, + ] + ) + + sanitized = _sanitize_messages([original]) + + assert sanitized[0].content == "visible answer" + assert original.content == [ + {"type": "thinking", "thinking": "private reasoning"}, + {"type": "text", "text": "visible answer"}, + ] + + +def test_sanitize_messages_sets_tool_only_ai_content_to_none() -> None: + message = AIMessage( + content="", + tool_calls=[{"name": "search", "args": {"q": "x"}, "id": "call_1"}], + ) + + sanitized = _sanitize_messages([message]) + + assert sanitized[0].content is None + assert message.content == "" + diff --git a/surfsense_backend/tests/unit/tasks/chat/test_llm_history_normalizer.py b/surfsense_backend/tests/unit/tasks/chat/test_llm_history_normalizer.py new file mode 100644 index 000000000..14bb030cd --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/test_llm_history_normalizer.py @@ -0,0 +1,62 @@ +"""Unit tests for provider-safe LLM history normalization.""" + +from __future__ import annotations + +import pytest + +from app.tasks.chat.llm_history_normalizer import ( + assistant_content_to_llm_text, + user_content_to_llm_content, +) + +pytestmark = pytest.mark.unit + + +def test_assistant_ui_parts_drop_thinking_steps_for_llm_history() -> None: + content = [ + {"type": "data-thinking-steps", "data": [{"id": "thinking-1"}]}, + {"type": "text", "text": "visible answer"}, + ] + + assert assistant_content_to_llm_text(content) == "visible answer" + + +def test_provider_thinking_blocks_are_not_replayed_to_llm() -> None: + content = [ + {"type": "thinking", "thinking": "private reasoning"}, + {"type": "text", "text": "final answer"}, + ] + + assert assistant_content_to_llm_text(content) == "final answer" + + +def test_unknown_assistant_blocks_are_dropped() -> None: + content = [ + {"type": "redacted_thinking", "data": "hidden"}, + {"type": "tool_use", "name": "search"}, + {"type": "text", "text": "kept"}, + ] + + assert assistant_content_to_llm_text(content) == "kept" + + +def test_user_images_convert_to_openai_compatible_image_url_blocks() -> None: + content = [ + {"type": "text", "text": "look"}, + {"type": "image", "image": "data:image/png;base64,abc"}, + ] + + assert user_content_to_llm_content(content, allow_images=True) == [ + {"type": "text", "text": "look"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + ] + + +def test_user_images_can_be_dropped_for_text_only_history() -> None: + content = [ + {"type": "text", "text": "look"}, + {"type": "image", "image": "data:image/png;base64,abc"}, + ] + + assert user_content_to_llm_content(content, allow_images=False) == "look" + diff --git a/surfsense_backend/tests/unit/tasks/chat/test_message_parts_normalizer.py b/surfsense_backend/tests/unit/tasks/chat/test_message_parts_normalizer.py new file mode 100644 index 000000000..173d03ed5 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/test_message_parts_normalizer.py @@ -0,0 +1,68 @@ +"""Unit tests for final assistant message part normalization.""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage + +from app.tasks.chat.message_parts_normalizer import ( + final_assistant_parts_from_messages, + merge_streamed_and_final_parts, + normalize_ai_message_to_parts, +) + +pytestmark = pytest.mark.unit + + +def test_string_ai_message_content_becomes_text_part() -> None: + assert normalize_ai_message_to_parts(AIMessage(content="hello")) == [ + {"type": "text", "text": "hello"} + ] + + +def test_deepseek_thinking_plus_text_blocks_backfill_only_text() -> None: + message = AIMessage( + content=[ + {"type": "thinking", "thinking": "hidden reasoning"}, + {"type": "text", "text": "Yo bro! What's up?"}, + ], + additional_kwargs={"reasoning_content": "hidden reasoning"}, + ) + + assert normalize_ai_message_to_parts(message) == [ + {"type": "text", "text": "Yo bro! What's up?"} + ] + + +def test_final_parts_use_last_ai_message_and_skip_trailing_tool_messages() -> None: + messages = [ + HumanMessage(content="ask"), + AIMessage(content="draft"), + ToolMessage(content="tool output", tool_call_id="tc-1"), + AIMessage(content=[{"type": "text", "text": "final answer"}]), + ToolMessage(content="trailing tool noise", tool_call_id="tc-2"), + ] + + assert final_assistant_parts_from_messages(messages) == [ + {"type": "text", "text": "final answer"} + ] + + +def test_merge_adds_final_text_when_stream_only_has_thinking_steps() -> None: + streamed = [ + { + "type": "data-thinking-steps", + "data": [{"id": "thinking-1", "status": "completed"}], + } + ] + final = [{"type": "text", "text": "visible answer"}] + + assert merge_streamed_and_final_parts(streamed, final) == [*streamed, *final] + + +def test_merge_does_not_duplicate_when_stream_already_has_text() -> None: + streamed = [{"type": "text", "text": "streamed answer"}] + final = [{"type": "text", "text": "final answer"}] + + assert merge_streamed_and_final_parts(streamed, final) == streamed + From 610ff063d6afab5b573cf5aab8657771f3a7bca6 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 02:17:51 +0530 Subject: [PATCH 24/59] refactor(model-connections): update frontend for provider-based models --- .../[search_space_id]/client-layout.tsx | 2 +- .../[search_space_id]/onboard/page.tsx | 2 +- .../model-connections-query.atoms.ts | 7 + .../components/assistant-ui/thread.tsx | 33 ++-- .../components/new-chat/model-selector.tsx | 13 +- .../settings/model-connections-settings.tsx | 182 ++++++++++-------- .../types/model-connections.types.ts | 44 +++-- .../hooks/use-automation-eligible-models.ts | 11 +- .../lib/apis/model-connections-api.service.ts | 6 + surfsense_web/lib/query-client/cache-keys.ts | 1 + 10 files changed, 177 insertions(+), 124 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx index c7e05fe99..2b16a038a 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx @@ -49,7 +49,7 @@ export function DashboardClientLayout({ const firstGlobalChatModel = useMemo(() => { for (const connection of globalConnections) { - const model = connection.models.find((item) => item.enabled && item.capabilities?.chat); + const model = connection.models.find((item) => item.enabled && item.supports_chat); if (model) return model; } return null; diff --git a/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx index 9cf429a3a..c6fc1c7a2 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx @@ -33,7 +33,7 @@ export default function OnboardPage() { const firstGlobalChatModel = useMemo(() => { for (const connection of globalConnections) { - const model = connection.models.find((item) => item.enabled && item.capabilities?.chat); + const model = connection.models.find((item) => item.enabled && item.supports_chat); if (model) return model; } return null; diff --git a/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts index 617ffe124..87f31ce9b 100644 --- a/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts @@ -11,6 +11,13 @@ export const globalModelConnectionsAtom = atomWithQuery(() => ({ queryFn: () => modelConnectionsApiService.getGlobalConnections(), })); +export const modelProvidersAtom = atomWithQuery(() => ({ + queryKey: cacheKeys.modelConnections.providers(), + enabled: !!getBearerToken(), + staleTime: 60 * 60 * 1000, + queryFn: () => modelConnectionsApiService.getModelProviders(), +})); + export const modelConnectionsAtom = atomWithQuery((get) => { const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); return { diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index 5796109f0..722ebb476 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -48,10 +48,10 @@ import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dial import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms"; import { membersAtom } from "@/atoms/members/members-query.atoms"; import { - globalNewLLMConfigsAtom, - llmPreferencesAtom, - newLLMConfigsAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; + globalModelConnectionsAtom, + modelConnectionsAtom, + modelRolesAtom, +} from "@/atoms/model-connections/model-connections-query.atoms"; import { currentUserAtom } from "@/atoms/user/user-query.atoms"; import { AssistantMessage } from "@/components/assistant-ui/assistant-message"; import { ChatSessionStatus } from "@/components/assistant-ui/chat-session-status"; @@ -976,9 +976,9 @@ const ComposerAction: FC = ({ isBlockedByOtherUser = false if (url) setPendingScreenImages((prev) => [...prev, url]); }, [electronAPI, setPendingScreenImages]); - const { data: userConfigs } = useAtomValue(newLLMConfigsAtom); - const { data: globalConfigs } = useAtomValue(globalNewLLMConfigsAtom); - const { data: preferences } = useAtomValue(llmPreferencesAtom); + const { data: globalModelConnections } = useAtomValue(globalModelConnectionsAtom); + const { data: modelConnections } = useAtomValue(modelConnectionsAtom); + const { data: modelRoles } = useAtomValue(modelRolesAtom); const { data: agentTools } = useAtomValue(agentToolsAtom); const disabledTools = useAtomValue(disabledToolsAtom); @@ -1065,15 +1065,18 @@ const ComposerAction: FC = ({ isBlockedByOtherUser = false }, [hydrateDisabled]); const hasModelConfigured = useMemo(() => { - if (!preferences) return false; - const agentLlmId = preferences.agent_llm_id; - if (agentLlmId === null || agentLlmId === undefined) return false; - - if (agentLlmId <= 0) { - return globalConfigs?.some((c) => c.id === agentLlmId) ?? false; + const chatModelId = modelRoles?.chat_model_id ?? 0; + if (chatModelId === 0) { + return [...(globalModelConnections ?? []), ...(modelConnections ?? [])].some((connection) => + connection.models.some((model) => model.enabled && Boolean(model.supports_chat)) + ); } - return userConfigs?.some((c) => c.id === agentLlmId) ?? false; - }, [preferences, globalConfigs, userConfigs]); + return [...(globalModelConnections ?? []), ...(modelConnections ?? [])].some((connection) => + connection.models.some( + (model) => model.id === chatModelId && model.enabled && Boolean(model.supports_chat) + ) + ); + }, [modelRoles?.chat_model_id, globalModelConnections, modelConnections]); const isSendDisabled = isComposerEmpty || !hasModelConfigured || isBlockedByOtherUser; diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 4744da617..6850096d6 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -56,18 +56,18 @@ function modelName(model: ModelRead) { function connectionLabel(connection: ConnectionRead) { if (connection.scope === "GLOBAL") return "Hosted"; - return connection.litellm_provider || connection.protocol; + return connection.provider; } function flattenChatModels(connections: ConnectionRead[]) { return connections.flatMap((connection) => connection.models - .filter((model) => model.enabled && Boolean(model.capabilities?.chat)) + .filter((model) => model.enabled && Boolean(model.supports_chat)) .map((model) => ({ ...model, connectionId: connection.id, connectionLabel: connectionLabel(connection), - provider: connection.litellm_provider || connection.protocol, + provider: connection.provider, })) ); } @@ -184,9 +184,14 @@ export function ModelSelector({ {modelName(model)}
{model.model_id}
+ {model.max_input_tokens ? ( +
+ {model.max_input_tokens.toLocaleString()} context +
+ ) : null}
- {!model.capabilities?.vision ? ( + {!model.supports_image_input ? ( No image diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 29501abda..0e541548b 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -1,11 +1,12 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { CheckCircle2, PlugZap, Plus, RefreshCcw, XCircle } from "lucide-react"; +import { CheckCircle2, PlugZap, Plus, RefreshCcw, Trash2, XCircle } from "lucide-react"; import { useState } from "react"; import { addManualModelMutationAtom, createModelConnectionMutationAtom, + deleteModelConnectionMutationAtom, discoverConnectionModelsMutationAtom, testModelMutationAtom, updateModelConnectionMutationAtom, @@ -16,6 +17,7 @@ import { import { globalModelConnectionsAtom, modelConnectionsAtom, + modelProvidersAtom, modelRolesAtom, } from "@/atoms/model-connections/model-connections-query.atoms"; import { Badge } from "@/components/ui/badge"; @@ -30,37 +32,9 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; -import type { - ConnectionProtocol, - ConnectionRead, - ModelRead, -} from "@/contracts/types/model-connections.types"; +import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; import { getProviderIcon } from "@/lib/provider-icons"; -const PROTOCOL_OPTIONS: { value: ConnectionProtocol; label: string; description: string }[] = [ - { - value: "OPENAI_COMPATIBLE", - label: "OpenAI-compatible", - description: "Use for OpenAI, OpenRouter, Groq, vLLM, LM Studio, and compatible APIs.", - }, - { - value: "ANTHROPIC", - label: "Anthropic", - description: "Use for Claude endpoints that require Anthropic headers.", - }, - { - value: "OLLAMA", - label: "Ollama", - description: "Use for Ollama's native API.", - }, -]; - -function defaultLitellmProvider(protocol: ConnectionProtocol) { - if (protocol === "OLLAMA") return "ollama_chat"; - if (protocol === "ANTHROPIC") return "anthropic"; - return "openai"; -} - // Free-text URL hints (datalist), mirroring OpenWebUI. These never restrict // what the user can type — any OpenAI-compatible endpoint works. const URL_SUGGESTIONS = [ @@ -82,9 +56,19 @@ function modelLabel(model: ModelRead) { } function capability(model: ModelRead, key: "chat" | "vision" | "image_gen") { - return Boolean(model.capabilities?.[key]); + if (key === "chat") return Boolean(model.supports_chat); + if (key === "vision") return Boolean(model.supports_image_input); + return Boolean(model.supports_image_generation); } +type ModelCapabilityFilter = "chat" | "vision" | "image_gen"; + +const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: string }[] = [ + { key: "chat", label: "Chat" }, + { key: "vision", label: "Vision" }, + { key: "image_gen", label: "Image" }, +]; + function StatusBadge({ connection }: { connection: ConnectionRead }) { if (connection.last_status === "OK") { return ( @@ -107,9 +91,9 @@ function flattenModels(connections: ConnectionRead[]) { return connections.flatMap((connection) => connection.models.map((model) => ({ ...model, - connectionName: connection.litellm_provider || connection.protocol, + connectionName: connection.provider, connectionId: connection.id, - provider: connection.litellm_provider || connection.protocol, + provider: connection.provider, })) ); } @@ -118,6 +102,7 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom); const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); const updateConnection = useAtomValue(updateModelConnectionMutationAtom); + const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom); const addManualModel = useAtomValue(addManualModelMutationAtom); const updateModel = useAtomValue(updateModelMutationAtom); const testModel = useAtomValue(testModelMutationAtom); @@ -127,9 +112,16 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { : []; const [allowlistText, setAllowlistText] = useState(allowlist.join(", ")); const [manualModelId, setManualModelId] = useState(""); + const [modelFilter, setModelFilter] = useState(null); - const providerLabel = connection.litellm_provider || connection.protocol; - const isLocal = connection.protocol === "OLLAMA" || !connection.base_url?.startsWith("https"); + const providerLabel = connection.provider; + const isLocal = + connection.provider === "ollama_chat" || + connection.provider === "lm_studio" || + !connection.base_url?.startsWith("https"); + const filteredModels = modelFilter + ? connection.models.filter((model) => capability(model, modelFilter)) + : connection.models; function saveAllowlist() { const ids = allowlistText @@ -151,6 +143,14 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { ); } + function deleteCurrentConnection() { + const confirmed = window.confirm( + `Delete the ${providerLabel} connection and all of its models? This cannot be undone.` + ); + if (!confirmed) return; + deleteConnection.mutate(connection.id); + } + return (
@@ -175,6 +175,14 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { +
@@ -232,8 +240,38 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
+ {connection.models.length > 0 ? ( +
+ Filter models + {MODEL_CAPABILITY_FILTERS.map((filter) => { + const count = connection.models.filter((model) => capability(model, filter.key)).length; + const isActive = modelFilter === filter.key; + + return ( + + ); + })} +
+ ) : null} +
- {connection.models.map((model) => ( + {filteredModels.length === 0 && modelFilter ? ( +
+ No {MODEL_CAPABILITY_FILTERS.find((filter) => filter.key === modelFilter)?.label.toLowerCase()}{" "} + models found on this connection. +
+ ) : null} + {filteredModels.map((model) => (
{["chat", "vision", "image_gen"] - .filter((key) => Boolean(model.capabilities?.[key])) - .join(", ") || "No verified capabilities"} + .filter((key) => capability(model, key as "chat" | "vision" | "image_gen")) + .join(", ") || "No discovered capabilities"}
@@ -278,18 +316,16 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: number }) { const [{ data: globalConnections = [] }] = useAtom(globalModelConnectionsAtom); const [{ data: connections = [] }] = useAtom(modelConnectionsAtom); + const [{ data: providers = [] }] = useAtom(modelProvidersAtom); const [{ data: roles }] = useAtom(modelRolesAtom); const createConnection = useAtomValue(createModelConnectionMutationAtom); const updateRoles = useAtomValue(updateModelRolesMutationAtom); - const [protocol, setProtocol] = useState("OPENAI_COMPATIBLE"); + const [provider, setProvider] = useState("openai_compatible"); const [baseUrl, setBaseUrl] = useState(""); const [apiKey, setApiKey] = useState(""); - const [litellmProvider, setLitellmProvider] = useState(""); - const [showAdvancedProvider, setShowAdvancedProvider] = useState(false); - const selectedProtocol = PROTOCOL_OPTIONS.find((item) => item.value === protocol); - const protocolDefaultProvider = defaultLitellmProvider(protocol); - const isOllama = protocol === "OLLAMA"; + const selectedProvider = providers.find((item) => item.provider === provider); + const isOllama = provider === "ollama_chat"; const allConnections = [...globalConnections, ...connections]; const enabledModels = flattenModels(allConnections).filter((model) => model.enabled); @@ -298,11 +334,9 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num const imageModels = enabledModels.filter((model) => capability(model, "image_gen")); function handleCreate() { - const explicitProvider = litellmProvider.trim(); createConnection.mutate( { - protocol, - litellm_provider: explicitProvider ? explicitProvider : null, + provider, base_url: baseUrl || null, api_key: apiKey || null, scope: "SEARCH_SPACE", @@ -337,18 +371,22 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
- + setLitellmProvider(event.target.value)} - placeholder={protocolDefaultProvider} - /> -

- Leave empty to use the protocol default. Set this for more accurate LiteLLM - capabilities/costs, for example openrouter, groq, gemini, or azure. -

-
- ) : null} -
diff --git a/surfsense_web/contracts/types/model-connections.types.ts b/surfsense_web/contracts/types/model-connections.types.ts index 7a37799c4..a34687d74 100644 --- a/surfsense_web/contracts/types/model-connections.types.ts +++ b/surfsense_web/contracts/types/model-connections.types.ts @@ -1,26 +1,19 @@ import { z } from "zod"; -export const connectionProtocolEnum = z.enum(["OLLAMA", "OPENAI_COMPATIBLE", "ANTHROPIC"]); export const connectionScopeEnum = z.enum(["GLOBAL", "SEARCH_SPACE", "USER"]); export const modelSourceEnum = z.enum(["DISCOVERED", "MANUAL"]); -export const modelCapabilities = z.object({ - chat: z.boolean().optional(), - vision: z.boolean().optional(), - image_gen: z.boolean().optional(), - embedding: z.boolean().optional(), - tools: z.boolean().optional(), -}); - export const modelRead = z.object({ id: z.number(), connection_id: z.number(), model_id: z.string(), display_name: z.string().nullable().optional(), source: z.union([modelSourceEnum, z.string()]), - capabilities: z.record(z.string(), z.any()).default({}), - capabilities_declared: z.record(z.string(), z.any()).default({}), - capabilities_verified: z.record(z.string(), z.any()).default({}), + supports_chat: z.boolean().nullable().optional(), + max_input_tokens: z.number().nullable().optional(), + supports_image_input: z.boolean().nullable().optional(), + supports_tools: z.boolean().nullable().optional(), + supports_image_generation: z.boolean().nullable().optional(), capabilities_override: z.record(z.string(), z.any()).default({}), embedding_dimension: z.number().nullable().optional(), enabled: z.boolean(), @@ -31,8 +24,7 @@ export const modelRead = z.object({ export const connectionRead = z.object({ id: z.number(), - protocol: z.union([connectionProtocolEnum, z.string()]), - litellm_provider: z.string().nullable().optional(), + provider: z.string(), base_url: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).default({}), scope: z.union([connectionScopeEnum, z.string()]), @@ -48,8 +40,7 @@ export const connectionRead = z.object({ }); export const connectionCreateRequest = z.object({ - protocol: connectionProtocolEnum, - litellm_provider: z.string().nullable().optional(), + provider: z.string().min(1), base_url: z.string().nullable().optional(), api_key: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).default({}), @@ -59,7 +50,7 @@ export const connectionCreateRequest = z.object({ }); export const connectionUpdateRequest = z.object({ - litellm_provider: z.string().nullable().optional(), + provider: z.string().nullable().optional(), base_url: z.string().nullable().optional(), api_key: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).optional(), @@ -74,6 +65,11 @@ export const modelCreateRequest = z.object({ export const modelUpdateRequest = z.object({ display_name: z.string().nullable().optional(), enabled: z.boolean().optional(), + supports_chat: z.boolean().nullable().optional(), + max_input_tokens: z.number().nullable().optional(), + supports_image_input: z.boolean().nullable().optional(), + supports_tools: z.boolean().nullable().optional(), + supports_image_generation: z.boolean().nullable().optional(), capabilities_override: z.record(z.string(), z.any()).optional(), }); @@ -89,10 +85,21 @@ export const modelRoles = z.object({ image_gen_model_id: z.number().nullable().optional(), }); +export const modelProviderRead = z.object({ + provider: z.string(), + transport: z.string(), + discovery: z.string(), + default_base_url: z.string().nullable().optional(), + base_url_required: z.boolean(), + auth_style: z.string(), + local_only: z.boolean().default(false), +}); + +export const modelProviderListResponse = z.array(modelProviderRead); + export const connectionListResponse = z.array(connectionRead); export const modelListResponse = z.array(modelRead); -export type ConnectionProtocol = z.infer; export type ConnectionScope = z.infer; export type ModelRead = z.infer; export type ConnectionRead = z.infer; @@ -102,3 +109,4 @@ export type ModelCreateRequest = z.infer; export type ModelUpdateRequest = z.infer; export type ModelRoles = z.infer; export type VerifyConnectionResponse = z.infer; +export type ModelProviderRead = z.infer; diff --git a/surfsense_web/hooks/use-automation-eligible-models.ts b/surfsense_web/hooks/use-automation-eligible-models.ts index f8b264162..fd3ad3a6a 100644 --- a/surfsense_web/hooks/use-automation-eligible-models.ts +++ b/surfsense_web/hooks/use-automation-eligible-models.ts @@ -47,11 +47,16 @@ function buildKind( capability: "chat" | "image_gen" | "vision", prefId: number | null | undefined ): EligibleModelKind { + const supportsCapability = (model: ModelRead) => { + if (capability === "chat") return Boolean(model.supports_chat); + if (capability === "vision") return Boolean(model.supports_image_input); + return Boolean(model.supports_image_generation); + }; const toOption = (connection: ConnectionRead, model: ModelRead, isBYOK: boolean) => ({ id: model.id, name: model.display_name || model.model_id, modelName: model.model_id, - provider: connection.litellm_provider || connection.protocol, + provider: connection.provider, isBYOK, }); @@ -60,7 +65,7 @@ function buildKind( .filter( (model) => model.enabled && - Boolean(model.capabilities?.[capability]) && + supportsCapability(model) && String(model.billing_tier ?? "").toLowerCase() === "premium" ) .map((model) => toOption(connection, model, false)) @@ -68,7 +73,7 @@ function buildKind( const byokOptions: EligibleModelOption[] = (byok ?? []).flatMap((connection) => connection.models - .filter((model) => model.enabled && Boolean(model.capabilities?.[capability])) + .filter((model) => model.enabled && supportsCapability(model)) .map((model) => toOption(connection, model, true)) ); diff --git a/surfsense_web/lib/apis/model-connections-api.service.ts b/surfsense_web/lib/apis/model-connections-api.service.ts index 92eec8e61..12ad8e0d2 100644 --- a/surfsense_web/lib/apis/model-connections-api.service.ts +++ b/surfsense_web/lib/apis/model-connections-api.service.ts @@ -7,10 +7,12 @@ import { connectionRead, connectionUpdateRequest, type ModelCreateRequest, + type ModelProviderRead, type ModelRead, type ModelRoles, type ModelUpdateRequest, modelCreateRequest, + modelProviderListResponse, modelListResponse, modelRead, modelRoles, @@ -26,6 +28,10 @@ class ModelConnectionsApiService { return baseApiService.get(`/api/v1/global-model-connections`, connectionListResponse); }; + getModelProviders = async (): Promise => { + return baseApiService.get(`/api/v1/model-providers`, modelProviderListResponse); + }; + getConnections = async (searchSpaceId: number): Promise => { return baseApiService.get( `/api/v1/model-connections?search_space_id=${searchSpaceId}`, diff --git a/surfsense_web/lib/query-client/cache-keys.ts b/surfsense_web/lib/query-client/cache-keys.ts index 558a73f95..5a3f0fb84 100644 --- a/surfsense_web/lib/query-client/cache-keys.ts +++ b/surfsense_web/lib/query-client/cache-keys.ts @@ -47,6 +47,7 @@ export const cacheKeys = { modelConnections: { all: (searchSpaceId: number) => ["model-connections", searchSpaceId] as const, global: () => ["model-connections", "global"] as const, + providers: () => ["model-connections", "providers"] as const, roles: (searchSpaceId: number) => ["model-roles", searchSpaceId] as const, }, imageGenConfigs: { From 5da3ab0552b6a03e2fcc0e5e7527b9210bd9424f Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 03:20:09 +0530 Subject: [PATCH 25/59] feat(database): rename add_model_connections alembic migration --- ..._model_connections.py => 160_add_model_connections.py} | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) rename surfsense_backend/alembic/versions/{156_add_model_connections.py => 160_add_model_connections.py} (99%) diff --git a/surfsense_backend/alembic/versions/156_add_model_connections.py b/surfsense_backend/alembic/versions/160_add_model_connections.py similarity index 99% rename from surfsense_backend/alembic/versions/156_add_model_connections.py rename to surfsense_backend/alembic/versions/160_add_model_connections.py index 64614db99..49d6315ca 100644 --- a/surfsense_backend/alembic/versions/156_add_model_connections.py +++ b/surfsense_backend/alembic/versions/160_add_model_connections.py @@ -1,7 +1,7 @@ """add model connections -Revision ID: 156 -Revises: 155 +Revision ID: 160 +Revises: 159 """ from collections.abc import Sequence @@ -11,8 +11,8 @@ from sqlalchemy.dialects import postgresql from alembic import op -revision: str = "156" -down_revision: str | None = "155" +revision: str = "160" +down_revision: str | None = "159" branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None From aba95e4faf1ed277eb0b27ba9e4b649e976113f2 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 03:35:49 +0530 Subject: [PATCH 26/59] feat(database): enhance podcast lifecycle management by adding temporary unpublishing during migration --- .../versions/158_evolve_podcasts_lifecycle.py | 28 ++++++++++ surfsense_backend/app/zero_publication.py | 51 ++++++++++++++++--- 2 files changed, 72 insertions(+), 7 deletions(-) diff --git a/surfsense_backend/alembic/versions/158_evolve_podcasts_lifecycle.py b/surfsense_backend/alembic/versions/158_evolve_podcasts_lifecycle.py index 15cf04f9d..7c51158a9 100644 --- a/surfsense_backend/alembic/versions/158_evolve_podcasts_lifecycle.py +++ b/surfsense_backend/alembic/versions/158_evolve_podcasts_lifecycle.py @@ -13,8 +13,34 @@ down_revision: str | None = "157" branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None +PUBLICATION_NAME = "zero_publication" + + +def _drop_podcasts_from_zero_publication() -> None: + """Temporarily unpublish podcasts while changing published columns.""" + + op.execute( + f""" + DO $$ + BEGIN + IF EXISTS ( + SELECT 1 + FROM pg_publication_tables + WHERE pubname = '{PUBLICATION_NAME}' + AND schemaname = current_schema() + AND tablename = 'podcasts' + ) THEN + ALTER PUBLICATION "{PUBLICATION_NAME}" DROP TABLE "podcasts"; + END IF; + END + $$; + """ + ) + def upgrade() -> None: + _drop_podcasts_from_zero_publication() + # Retype the status enum by swapping in a fresh type and casting existing # rows. The legacy transient value 'generating' maps onto 'rendering'. op.execute("ALTER TYPE podcast_status RENAME TO podcast_status_old;") @@ -57,6 +83,8 @@ def upgrade() -> None: def downgrade() -> None: + _drop_podcasts_from_zero_publication() + op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS error;") op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS duration_seconds;") op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS storage_key;") diff --git a/surfsense_backend/app/zero_publication.py b/surfsense_backend/app/zero_publication.py index 139286ee6..869559c55 100644 --- a/surfsense_backend/app/zero_publication.py +++ b/surfsense_backend/app/zero_publication.py @@ -100,12 +100,32 @@ def _column_exists(conn: Connection, table: str, column: str) -> bool: ) -def _expected_columns(conn: Connection, table: str) -> list[str] | None: +def _table_exists(conn: Connection, table: str) -> bool: + return ( + conn.execute( + text( + "SELECT 1 FROM information_schema.tables " + "WHERE table_schema = current_schema() " + "AND table_name = :table" + ), + {"table": table}, + ).fetchone() + is not None + ) + + +def _expected_columns( + conn: Connection, table: str, *, include_missing_columns: bool = True +) -> list[str] | None: columns = ZERO_PUBLICATION[table] if columns is None: return None - expected = list(columns) + if include_missing_columns: + expected = list(columns) + else: + expected = [column for column in columns if _column_exists(conn, table, column)] + if table in {"documents", "user", "podcasts"} and _column_exists( conn, table, "_0_version" ): @@ -113,11 +133,20 @@ def _expected_columns(conn: Connection, table: str) -> list[str] | None: return expected -def _format_table_entry(conn: Connection, table: str) -> str: - columns = _expected_columns(conn, table) +def _format_table_entry( + conn: Connection, table: str, *, include_missing_columns: bool = True +) -> str | None: + if not include_missing_columns and not _table_exists(conn, table): + return None + + columns = _expected_columns( + conn, table, include_missing_columns=include_missing_columns + ) table_sql = _quote_identifier(table) if columns is None: return table_sql + if not include_missing_columns and not columns: + return None column_sql = ", ".join(_quote_identifier(column) for column in columns) return f"{table_sql} ({column_sql})" @@ -126,9 +155,17 @@ def _format_table_entry(conn: Connection, table: str) -> str: def build_set_table_sql(conn: Connection) -> str: """Build the canonical plain SET TABLE statement for Zero's event triggers.""" - table_list = ", ".join( - _format_table_entry(conn, table) for table in ZERO_PUBLICATION - ) + table_entries = [ + entry + for table in ZERO_PUBLICATION + if ( + entry := _format_table_entry( + conn, table, include_missing_columns=False + ) + ) + is not None + ] + table_list = ", ".join(table_entries) return f"ALTER PUBLICATION {_quote_identifier(PUBLICATION_NAME)} SET TABLE {table_list}" From 8e8cf96faa629e7f86ef60ed2f12ed61012b2a4f Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 05:03:14 +0530 Subject: [PATCH 27/59] feat(error-handling): implement LLM error adaptation and classification for chat streaming - Introduced LLMErrorCategory and adapt_llm_exception to normalize LLM exceptions. - Updated llm_retryable_message and llm_permanent_message to utilize the new adaptation logic. - Enhanced classify_stream_exception to classify provider errors and return user-friendly messages. - Added tests for error classification and adaptation to ensure robustness. - Updated frontend error handling to display appropriate messages based on new classifications. --- .../app/indexing_pipeline/exceptions.py | 36 +-- .../app/routes/anonymous_chat_routes.py | 11 +- .../app/services/llm_error_adapter.py | 251 ++++++++++++++++++ .../tasks/chat/streaming/errors/classifier.py | 94 ++++++- .../chat/streaming/test_error_classifier.py | 80 ++++++ .../new-chat/[[...chat_id]]/page.tsx | 12 + .../components/free-chat/free-chat-page.tsx | 17 +- .../lib/chat/chat-error-classifier.ts | 66 ++++- surfsense_web/lib/chat/chat-request-errors.ts | 4 + 9 files changed, 533 insertions(+), 38 deletions(-) create mode 100644 surfsense_backend/app/services/llm_error_adapter.py create mode 100644 surfsense_backend/tests/unit/tasks/chat/streaming/test_error_classifier.py diff --git a/surfsense_backend/app/indexing_pipeline/exceptions.py b/surfsense_backend/app/indexing_pipeline/exceptions.py index 666fa4b9f..bf9d9e9fa 100644 --- a/surfsense_backend/app/indexing_pipeline/exceptions.py +++ b/surfsense_backend/app/indexing_pipeline/exceptions.py @@ -14,6 +14,8 @@ from litellm.exceptions import ( ) from sqlalchemy.exc import IntegrityError as IntegrityError +from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception + # Tuples for use directly in except clauses. RETRYABLE_LLM_ERRORS = ( RateLimitError, @@ -97,38 +99,20 @@ def safe_exception_message(exc: Exception) -> str: def llm_retryable_message(exc: Exception) -> str: try: - if isinstance(exc, RateLimitError): - return PipelineMessages.RATE_LIMIT - if isinstance(exc, Timeout): - return PipelineMessages.LLM_TIMEOUT - if isinstance(exc, ServiceUnavailableError): - return PipelineMessages.LLM_UNAVAILABLE - if isinstance(exc, BadGatewayError): - return PipelineMessages.LLM_BAD_GATEWAY - if isinstance(exc, InternalServerError): - return PipelineMessages.LLM_SERVER_ERROR - if isinstance(exc, APIConnectionError): - return PipelineMessages.LLM_CONNECTION - return safe_exception_message(exc) + adapted = adapt_llm_exception(exc) + if adapted.category is LLMErrorCategory.UNKNOWN: + return safe_exception_message(exc) + return adapted.user_message except Exception: return "Something went wrong when calling the LLM." def llm_permanent_message(exc: Exception) -> str: try: - if isinstance(exc, AuthenticationError): - return PipelineMessages.LLM_AUTH - if isinstance(exc, PermissionDeniedError): - return PipelineMessages.LLM_PERMISSION - if isinstance(exc, NotFoundError): - return PipelineMessages.LLM_NOT_FOUND - if isinstance(exc, BadRequestError): - return PipelineMessages.LLM_BAD_REQUEST - if isinstance(exc, UnprocessableEntityError): - return PipelineMessages.LLM_UNPROCESSABLE - if isinstance(exc, APIResponseValidationError): - return PipelineMessages.LLM_RESPONSE - return safe_exception_message(exc) + adapted = adapt_llm_exception(exc) + if adapted.category is LLMErrorCategory.UNKNOWN: + return safe_exception_message(exc) + return adapted.user_message except Exception: return "Something went wrong when calling the LLM." diff --git a/surfsense_backend/app/routes/anonymous_chat_routes.py b/surfsense_backend/app/routes/anonymous_chat_routes.py index 84420e738..aa0e70464 100644 --- a/surfsense_backend/app/routes/anonymous_chat_routes.py +++ b/surfsense_backend/app/routes/anonymous_chat_routes.py @@ -18,6 +18,7 @@ from app.etl_pipeline.file_classifier import ( PLAINTEXT_EXTENSIONS, ) from app.rate_limiter import limiter +from app.tasks.chat.streaming.errors.classifier import classify_stream_exception logger = logging.getLogger(__name__) @@ -474,7 +475,15 @@ async def stream_anonymous_chat( except Exception as e: logger.exception("Anonymous chat stream error") await TokenQuotaService.anon_release(session_key, ip_key, request_id) - yield streaming_service.format_error(f"Error during chat: {e!s}") + _, error_code, _, _, user_message, extra = classify_stream_exception( + e, + flow_label="chat", + ) + yield streaming_service.format_error( + user_message, + error_code=error_code, + extra=extra, + ) yield streaming_service.format_done() finally: await TokenQuotaService.anon_release_stream_slot(client_ip) diff --git a/surfsense_backend/app/services/llm_error_adapter.py b/surfsense_backend/app/services/llm_error_adapter.py new file mode 100644 index 000000000..b0de15fb0 --- /dev/null +++ b/surfsense_backend/app/services/llm_error_adapter.py @@ -0,0 +1,251 @@ +"""Normalize provider/LLM exceptions into low-cardinality product categories.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from enum import StrEnum +from typing import Any + + +class LLMErrorCategory(StrEnum): + RATE_LIMITED = "rate_limited" + TIMEOUT = "timeout" + PROVIDER_UNAVAILABLE = "provider_unavailable" + BAD_GATEWAY = "bad_gateway" + CONNECTION_FAILED = "connection_failed" + AUTH_FAILED = "auth_failed" + PERMISSION_DENIED = "permission_denied" + MODEL_NOT_FOUND = "model_not_found" + BAD_REQUEST = "bad_request" + CONTEXT_LIMIT = "context_limit" + RESPONSE_INVALID = "response_invalid" + SERVER_ERROR = "server_error" + UNKNOWN = "unknown" + + +@dataclass(frozen=True) +class LLMErrorAdaptation: + category: LLMErrorCategory + retryable: bool + user_message: str + provider_status_code: int | None = None + provider_error_type: str | None = None + + +_CATEGORY_MESSAGES: dict[LLMErrorCategory, str] = { + LLMErrorCategory.RATE_LIMITED: "LLM rate limit exceeded. Will retry on next sync.", + LLMErrorCategory.TIMEOUT: "LLM request timed out. Will retry on next sync.", + LLMErrorCategory.PROVIDER_UNAVAILABLE: "LLM service temporarily unavailable. Will retry on next sync.", + LLMErrorCategory.BAD_GATEWAY: "LLM gateway error. Will retry on next sync.", + LLMErrorCategory.CONNECTION_FAILED: "Could not reach the LLM service. Check network connectivity.", + LLMErrorCategory.AUTH_FAILED: "LLM authentication failed. Check your API key.", + LLMErrorCategory.PERMISSION_DENIED: "LLM request denied. Check your account permissions.", + LLMErrorCategory.MODEL_NOT_FOUND: "Model not found. Check your model configuration.", + LLMErrorCategory.BAD_REQUEST: "LLM rejected the request. Document content may be invalid.", + LLMErrorCategory.CONTEXT_LIMIT: "Document exceeds the LLM context window even after optimization.", + LLMErrorCategory.RESPONSE_INVALID: "LLM returned an invalid response.", + LLMErrorCategory.SERVER_ERROR: "LLM internal server error. Will retry on next sync.", + LLMErrorCategory.UNKNOWN: "Something went wrong when calling the LLM.", +} + +_RETRYABLE_CATEGORIES = { + LLMErrorCategory.RATE_LIMITED, + LLMErrorCategory.TIMEOUT, + LLMErrorCategory.PROVIDER_UNAVAILABLE, + LLMErrorCategory.BAD_GATEWAY, + LLMErrorCategory.CONNECTION_FAILED, + LLMErrorCategory.SERVER_ERROR, +} + +_CLASS_NAME_MAP: tuple[tuple[LLMErrorCategory, tuple[str, ...]], ...] = ( + ( + LLMErrorCategory.RATE_LIMITED, + ("RateLimitError", "TooManyRequests", "TooManyRequestsError"), + ), + (LLMErrorCategory.TIMEOUT, ("Timeout", "APITimeoutError", "TimeoutException")), + ( + LLMErrorCategory.PROVIDER_UNAVAILABLE, + ("ServiceUnavailableError", "ServiceUnavailable"), + ), + ( + LLMErrorCategory.BAD_GATEWAY, + ("BadGatewayError", "GatewayTimeoutError"), + ), + ( + LLMErrorCategory.CONNECTION_FAILED, + ("APIConnectionError", "ConnectError", "ConnectTimeout", "ReadTimeout"), + ), + ( + LLMErrorCategory.AUTH_FAILED, + ("AuthenticationError", "InvalidApiKey", "InvalidAPIKey", "InvalidApiKeyError"), + ), + (LLMErrorCategory.PERMISSION_DENIED, ("PermissionDeniedError", "ForbiddenError")), + (LLMErrorCategory.MODEL_NOT_FOUND, ("NotFoundError", "ModelNotFoundError")), + ( + LLMErrorCategory.CONTEXT_LIMIT, + ("ContextWindowExceeded", "ContextOverflow", "ContextLimit"), + ), + ( + LLMErrorCategory.RESPONSE_INVALID, + ("APIResponseValidationError", "ResponseValidationError"), + ), + ( + LLMErrorCategory.BAD_REQUEST, + ("BadRequestError", "InvalidRequestError", "UnprocessableEntityError"), + ), + (LLMErrorCategory.SERVER_ERROR, ("InternalServerError",)), +) + + +def _parse_error_payload(message: str) -> dict[str, Any] | None: + candidates = [message] + first_brace_idx = message.find("{") + if first_brace_idx >= 0: + candidates.append(message[first_brace_idx:]) + + for candidate in candidates: + try: + parsed = json.loads(candidate) + if isinstance(parsed, dict): + return parsed + except Exception: + continue + return None + + +def _class_names(exc: BaseException) -> tuple[str, ...]: + return tuple(cls.__name__ for cls in type(exc).__mro__) + + +def _category_from_class_name(exc: BaseException) -> LLMErrorCategory | None: + names = _class_names(exc) + for category, hints in _CLASS_NAME_MAP: + if any(any(hint in name for hint in hints) for name in names): + return category + return None + + +def _extract_provider_status_code(parsed: dict[str, Any] | None) -> int | None: + if not isinstance(parsed, dict): + return None + candidates: list[Any] = [parsed.get("code"), parsed.get("status")] + nested = parsed.get("error") + if isinstance(nested, dict): + candidates.extend([nested.get("code"), nested.get("status")]) + for value in candidates: + try: + if value is None: + continue + return int(value) + except Exception: + continue + return None + + +def _extract_provider_error_type(parsed: dict[str, Any] | None) -> str | None: + if not isinstance(parsed, dict): + return None + candidates: list[Any] = [parsed.get("type")] + nested = parsed.get("error") + if isinstance(nested, dict): + candidates.append(nested.get("type")) + for value in candidates: + if isinstance(value, str) and value: + return value + return None + + +def _category_from_provider_payload( + status_code: int | None, + provider_error_type: str | None, +) -> LLMErrorCategory | None: + if status_code == 429: + return LLMErrorCategory.RATE_LIMITED + if status_code == 401: + return LLMErrorCategory.AUTH_FAILED + if status_code == 403: + return LLMErrorCategory.PERMISSION_DENIED + if status_code == 404: + return LLMErrorCategory.MODEL_NOT_FOUND + if status_code in (400, 422): + return LLMErrorCategory.BAD_REQUEST + if status_code in (502, 504): + return LLMErrorCategory.BAD_GATEWAY + if status_code == 503: + return LLMErrorCategory.PROVIDER_UNAVAILABLE + if status_code is not None and status_code >= 500: + return LLMErrorCategory.SERVER_ERROR + + normalized_type = (provider_error_type or "").lower() + if normalized_type == "rate_limit_error": + return LLMErrorCategory.RATE_LIMITED + if normalized_type in {"authentication_error", "invalid_api_key", "invalid_api_key_error"}: + return LLMErrorCategory.AUTH_FAILED + if normalized_type in {"permission_denied", "forbidden"}: + return LLMErrorCategory.PERMISSION_DENIED + if normalized_type in {"not_found_error", "model_not_found"}: + return LLMErrorCategory.MODEL_NOT_FOUND + if normalized_type in {"context_length_exceeded", "context_window_exceeded"}: + return LLMErrorCategory.CONTEXT_LIMIT + return None + + +def _category_from_message(raw: str) -> LLMErrorCategory | None: + lowered = raw.lower() + if any(hint in lowered for hint in ("rate limit", "rate-limited", "temporarily rate-limited")): + return LLMErrorCategory.RATE_LIMITED + if any( + hint in lowered + for hint in ( + "invalid api key", + "invalid_api_key", + "authentication", + "unauthorized", + "user not found", + "api key is expired", + "expired api key", + ) + ): + return LLMErrorCategory.AUTH_FAILED + if "forbidden" in lowered or "permission denied" in lowered: + return LLMErrorCategory.PERMISSION_DENIED + if "model not found" in lowered: + return LLMErrorCategory.MODEL_NOT_FOUND + if any( + hint in lowered + for hint in ( + "context length", + "context window", + "maximum context", + "too many tokens", + ) + ): + return LLMErrorCategory.CONTEXT_LIMIT + return None + + +def adapt_llm_exception(exc: BaseException) -> LLMErrorAdaptation: + raw = str(exc) + parsed = _parse_error_payload(raw) + status_code = _extract_provider_status_code(parsed) + provider_error_type = _extract_provider_error_type(parsed) + + category = ( + _category_from_provider_payload(status_code, provider_error_type) + or _category_from_message(raw) + or _category_from_class_name(exc) + or LLMErrorCategory.UNKNOWN + ) + return LLMErrorAdaptation( + category=category, + retryable=category in _RETRYABLE_CATEGORIES, + user_message=_CATEGORY_MESSAGES[category], + provider_status_code=status_code, + provider_error_type=provider_error_type, + ) + + +def llm_error_message(exc: BaseException) -> str: + return adapt_llm_exception(exc).user_message + diff --git a/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py b/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py index 6b37df343..269143af2 100644 --- a/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py +++ b/surfsense_backend/app/tasks/chat/streaming/errors/classifier.py @@ -12,6 +12,7 @@ from app.agents.chat.multi_agent_chat.main_agent.middleware.busy_mutex import ( is_cancel_requested, ) from app.agents.chat.runtime.errors import BusyError +from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception TURN_CANCELLING_INITIAL_DELAY_MS = 200 TURN_CANCELLING_BACKOFF_FACTOR = 2 @@ -102,6 +103,9 @@ def _extract_provider_error_code(parsed: dict[str, Any] | None) -> int | None: def is_provider_rate_limited(exc: BaseException) -> bool: """Return True if the exception looks like an upstream HTTP 429 / rate limit.""" + if adapt_llm_exception(exc).category is LLMErrorCategory.RATE_LIMITED: + return True + raw = str(exc) lowered = raw.lower() if "ratelimit" in type(exc).__name__.lower(): @@ -131,6 +135,84 @@ def is_provider_rate_limited(exc: BaseException) -> bool: ) +def _provider_error_extra(adapted: Any) -> dict[str, Any] | None: + extra: dict[str, Any] = {"provider_error_category": adapted.category.value} + if adapted.provider_status_code is not None: + extra["provider_status_code"] = adapted.provider_status_code + if adapted.provider_error_type: + extra["provider_error_type"] = adapted.provider_error_type + return extra + + +def _classify_provider_exception( + exc: Exception, +) -> tuple[ + str, str, Literal["info", "warn", "error"], bool, str, dict[str, Any] | None +] | None: + adapted = adapt_llm_exception(exc) + + if adapted.category is LLMErrorCategory.RATE_LIMITED: + return ( + "rate_limited", + "RATE_LIMITED", + "warn", + True, + "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", + _provider_error_extra(adapted), + ) + + if adapted.category in { + LLMErrorCategory.AUTH_FAILED, + LLMErrorCategory.PERMISSION_DENIED, + }: + return ( + "model_auth_failed", + "MODEL_AUTH_FAILED", + "warn", + True, + "This model's API key is invalid or expired. Switch models, or update the API key.", + _provider_error_extra(adapted), + ) + + if adapted.category is LLMErrorCategory.MODEL_NOT_FOUND: + return ( + "model_not_found", + "MODEL_NOT_FOUND", + "warn", + True, + "The selected model is unavailable or no longer exists. Switch to another model and try again.", + _provider_error_extra(adapted), + ) + + if adapted.category is LLMErrorCategory.CONTEXT_LIMIT: + return ( + "model_context_limit", + "MODEL_CONTEXT_LIMIT", + "warn", + True, + "This request is too large for the selected model. Try a model with a larger context window or reduce the input.", + _provider_error_extra(adapted), + ) + + if adapted.category in { + LLMErrorCategory.TIMEOUT, + LLMErrorCategory.PROVIDER_UNAVAILABLE, + LLMErrorCategory.BAD_GATEWAY, + LLMErrorCategory.CONNECTION_FAILED, + LLMErrorCategory.SERVER_ERROR, + }: + return ( + "model_provider_unavailable", + "MODEL_PROVIDER_UNAVAILABLE", + "warn", + True, + "The selected model provider is temporarily unavailable. Please try again or switch models.", + _provider_error_extra(adapted), + ) + + return None + + def classify_stream_exception( exc: Exception, *, @@ -167,15 +249,9 @@ def classify_stream_exception( None, ) - if is_provider_rate_limited(exc): - return ( - "rate_limited", - "RATE_LIMITED", - "warn", - True, - "This model is temporarily rate-limited. Please try again in a few seconds or switch models.", - None, - ) + provider_classification = _classify_provider_exception(exc) + if provider_classification is not None: + return provider_classification return ( "server_error", diff --git a/surfsense_backend/tests/unit/tasks/chat/streaming/test_error_classifier.py b/surfsense_backend/tests/unit/tasks/chat/streaming/test_error_classifier.py new file mode 100644 index 000000000..48b07596c --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/chat/streaming/test_error_classifier.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import pytest + +from app.services.llm_error_adapter import LLMErrorCategory, adapt_llm_exception +from app.tasks.chat.streaming.errors.classifier import classify_stream_exception + +pytestmark = pytest.mark.unit + + +def _exception_named(name: str, message: str) -> Exception: + return type(name, (Exception,), {})(message) + + +def test_adapter_classifies_authentication_error_by_class_name() -> None: + exc = _exception_named("AuthenticationError", "provider rejected credentials") + + adapted = adapt_llm_exception(exc) + + assert adapted.category is LLMErrorCategory.AUTH_FAILED + assert adapted.retryable is False + assert adapted.user_message == "LLM authentication failed. Check your API key." + + +def test_adapter_classifies_embedded_provider_401_payload() -> None: + exc = RuntimeError( + 'litellm.AuthenticationError: OpenrouterException - {"error":{"message":"User not found.","code":401}}' + ) + + adapted = adapt_llm_exception(exc) + + assert adapted.category is LLMErrorCategory.AUTH_FAILED + assert adapted.provider_status_code == 401 + + +def test_adapter_preserves_rate_limit_classification() -> None: + exc = RuntimeError('{"error":{"message":"Slow down","code":429}}') + + adapted = adapt_llm_exception(exc) + + assert adapted.category is LLMErrorCategory.RATE_LIMITED + assert adapted.retryable is True + + +def test_stream_classifier_maps_model_auth_to_stable_code() -> None: + exc = RuntimeError( + 'litellm.AuthenticationError: OpenrouterException - {"error":{"message":"User not found.","code":401}}' + ) + + kind, code, severity, expected, message, extra = classify_stream_exception( + exc, + flow_label="chat", + ) + + assert kind == "model_auth_failed" + assert code == "MODEL_AUTH_FAILED" + assert severity == "warn" + assert expected is True + assert "API key" in message + assert extra == { + "provider_error_category": "auth_failed", + "provider_status_code": 401, + } + + +def test_stream_classifier_keeps_unknown_errors_generic() -> None: + exc = RuntimeError("database exploded") + + kind, code, severity, expected, message, extra = classify_stream_exception( + exc, + flow_label="chat", + ) + + assert kind == "server_error" + assert code == "SERVER_ERROR" + assert severity == "error" + assert expected is False + assert message == "Error during chat: database exploded" + assert extra is None + diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index f048376cc..0c4fa63ec 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -613,6 +613,18 @@ export default function NewChatPage() { return; } + if (normalized.channel === "inline") { + if (normalized.assistantMessage) { + await persistAssistantErrorMessage({ + threadId, + assistantMsgId, + text: normalized.assistantMessage, + }); + } + toast.error(normalized.userMessage); + return; + } + toast.error(normalized.userMessage); }, [currentUser?.id, persistAssistantErrorMessage, searchSpaceId, setPremiumAlertForThread] diff --git a/surfsense_web/components/free-chat/free-chat-page.tsx b/surfsense_web/components/free-chat/free-chat-page.tsx index b28b1e0a1..8d5215fca 100644 --- a/surfsense_web/components/free-chat/free-chat-page.tsx +++ b/surfsense_web/components/free-chat/free-chat-page.tsx @@ -63,6 +63,21 @@ function normalizeFreeChatErrorMessage(error: unknown): string { if (code === "THREAD_BUSY") { return "A previous response is still stopping. Please try again in a moment."; } + if (code === "MODEL_AUTH_FAILED") { + return "This model’s API key is invalid or expired. Switch models, or update the API key."; + } + if (code === "MODEL_NOT_FOUND") { + return "This model is unavailable or no longer exists. Please switch models."; + } + if (code === "MODEL_CONTEXT_LIMIT") { + return "This request is too large for the selected model. Reduce the input or switch models."; + } + if (code === "MODEL_PROVIDER_UNAVAILABLE") { + return "The selected model provider is temporarily unavailable. Please try again or switch models."; + } + if (code === "RATE_LIMITED") { + return "This model is temporarily rate-limited. Please try again in a few seconds or switch models."; + } return error.message || "An unexpected error occurred"; } @@ -154,7 +169,7 @@ export function FreeChatPage() { assistantMsgId: string, signal: AbortSignal, turnstileToken: string | null - ): Promise<"captcha" | void> => { + ): Promise<"captcha" | undefined> => { const reqBody: Record = { model_slug: modelSlug, messages: messageHistory, diff --git a/surfsense_web/lib/chat/chat-error-classifier.ts b/surfsense_web/lib/chat/chat-error-classifier.ts index 1c67d59a1..92924f0f7 100644 --- a/surfsense_web/lib/chat/chat-error-classifier.ts +++ b/surfsense_web/lib/chat/chat-error-classifier.ts @@ -5,6 +5,10 @@ export type ChatErrorKind = | "thread_busy" | "send_failed_pre_accept" | "auth_expired" + | "model_auth_failed" + | "model_not_found" + | "model_context_limit" + | "model_provider_unavailable" | "rate_limited" | "network_offline" | "stream_interrupted" @@ -14,7 +18,7 @@ export type ChatErrorKind = | "server_error" | "unknown"; -export type ChatErrorChannel = "pinned_inline" | "toast" | "silent"; +export type ChatErrorChannel = "pinned_inline" | "inline" | "toast" | "silent"; export type ChatTelemetryEvent = "chat_blocked" | "chat_error"; export type ChatErrorSeverity = "info" | "warn" | "error"; @@ -206,6 +210,66 @@ export function classifyChatError(input: RawChatErrorInput): NormalizedChatError }; } + if (errorCode === "MODEL_AUTH_FAILED") { + return { + kind: "model_auth_failed", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: + "This model’s API key is invalid or expired. Switch models, or update the API key.", + rawMessage, + errorCode: errorCode ?? "MODEL_AUTH_FAILED", + details: { flow: input.flow, providerErrorType }, + }; + } + + if (errorCode === "MODEL_NOT_FOUND") { + return { + kind: "model_not_found", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: + "This model is unavailable or no longer exists. Switch to another model and try again.", + rawMessage, + errorCode: errorCode ?? "MODEL_NOT_FOUND", + details: { flow: input.flow, providerErrorType }, + }; + } + + if (errorCode === "MODEL_CONTEXT_LIMIT") { + return { + kind: "model_context_limit", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: + "This request is too large for the selected model. Reduce the input or switch models.", + rawMessage, + errorCode: errorCode ?? "MODEL_CONTEXT_LIMIT", + details: { flow: input.flow, providerErrorType }, + }; + } + + if (errorCode === "MODEL_PROVIDER_UNAVAILABLE") { + return { + kind: "model_provider_unavailable", + channel: "toast", + severity: "warn", + telemetryEvent: "chat_blocked", + isExpected: true, + userMessage: + "The selected model provider is temporarily unavailable. Please try again or switch models.", + rawMessage, + errorCode: errorCode ?? "MODEL_PROVIDER_UNAVAILABLE", + details: { flow: input.flow, providerErrorType }, + }; + } + if (errorCode === "RATE_LIMITED" || providerTypeNormalized === "rate_limit_error") { return { kind: "rate_limited", diff --git a/surfsense_web/lib/chat/chat-request-errors.ts b/surfsense_web/lib/chat/chat-request-errors.ts index e0dfb3cc4..c86c72d66 100644 --- a/surfsense_web/lib/chat/chat-request-errors.ts +++ b/surfsense_web/lib/chat/chat-request-errors.ts @@ -91,6 +91,10 @@ export function tagPreAcceptSendFailure(error: unknown): unknown { "TURN_CANCELLING", "AUTH_EXPIRED", "UNAUTHORIZED", + "MODEL_AUTH_FAILED", + "MODEL_NOT_FOUND", + "MODEL_CONTEXT_LIMIT", + "MODEL_PROVIDER_UNAVAILABLE", "RATE_LIMITED", "NETWORK_ERROR", "STREAM_PARSE_ERROR", From ad404b2dbc85ad44a25deea5c0b1979785379911 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 05:15:15 +0530 Subject: [PATCH 28/59] refactor(icons): replace Workflow icon with Clock3 across automation components --- .../automations/components/automations-empty-state.tsx | 4 ++-- .../automations/components/automations-table.tsx | 4 ++-- .../components/layout/providers/LayoutDataProvider.tsx | 4 ++-- surfsense_web/components/new-chat/chat-example-prompts.tsx | 4 ++-- .../components/tool-ui/automation/create-automation.tsx | 6 +++--- surfsense_web/contracts/enums/toolIcons.tsx | 4 ++-- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx index b2e7b2532..70d9990f8 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx @@ -1,5 +1,5 @@ "use client"; -import { Workflow } from "lucide-react"; +import { Clock3 } from "lucide-react"; import Link from "next/link"; import { Button } from "@/components/ui/button"; @@ -18,7 +18,7 @@ export function AutomationsEmptyState({ searchSpaceId, canCreate }: AutomationsE return (
- +

No automations yet

diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx index 8314a5179..727636b43 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx @@ -1,5 +1,5 @@ "use client"; -import { CalendarDays, Info, Workflow } from "lucide-react"; +import { CalendarDays, Clock3, Info } from "lucide-react"; import { Table, TableBody, TableHead, TableHeader, TableRow } from "@/components/ui/table"; import type { AutomationSummary } from "@/contracts/types/automation.types"; import { AutomationRow } from "./automation-row"; @@ -31,7 +31,7 @@ export function AutomationsTable({ - + Name diff --git a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx index 549e6e7d7..5c62f6a7d 100644 --- a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx +++ b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx @@ -2,7 +2,7 @@ import { useQuery } from "@tanstack/react-query"; import { useAtom, useAtomValue, useSetAtom } from "jotai"; -import { AlertTriangle, Inbox, LibraryBig, Workflow } from "lucide-react"; +import { AlertTriangle, Clock3, Inbox, LibraryBig } from "lucide-react"; import { useParams, usePathname, useRouter } from "next/navigation"; import { useTranslations } from "next-intl"; import { useTheme } from "next-themes"; @@ -342,7 +342,7 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid { title: "Automations", url: `/dashboard/${searchSpaceId}/automations`, - icon: Workflow, + icon: Clock3, isActive: isAutomationsActive, }, isMobile diff --git a/surfsense_web/components/new-chat/chat-example-prompts.tsx b/surfsense_web/components/new-chat/chat-example-prompts.tsx index 98d95b98b..4fdc32a92 100644 --- a/surfsense_web/components/new-chat/chat-example-prompts.tsx +++ b/surfsense_web/components/new-chat/chat-example-prompts.tsx @@ -1,12 +1,12 @@ "use client"; import { + Clock3, FilePlus2, Search, Settings2, type LucideIcon, WandSparkles, - Workflow, X, } from "lucide-react"; import { memo, useCallback, useState } from "react"; @@ -22,7 +22,7 @@ interface ChatExamplePromptsProps { const CATEGORY_ICONS: Record = { search: Search, create: FilePlus2, - automate: Workflow, + automate: Clock3, tools: Settings2, }; diff --git a/surfsense_web/components/tool-ui/automation/create-automation.tsx b/surfsense_web/components/tool-ui/automation/create-automation.tsx index 2a7d09f53..644ccd822 100644 --- a/surfsense_web/components/tool-ui/automation/create-automation.tsx +++ b/surfsense_web/components/tool-ui/automation/create-automation.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useAtomValue } from "jotai"; -import { AlertCircle, CornerDownLeftIcon, ExternalLink, Pencil, Workflow } from "lucide-react"; +import { AlertCircle, Clock3, CornerDownLeftIcon, ExternalLink, Pencil } from "lucide-react"; import Link from "next/link"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { @@ -211,7 +211,7 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) {

- +

{phase === "rejected" @@ -404,7 +404,7 @@ function SavedCard({ result }: { result: SavedResult }) { return (

- +

Automation saved

{result.name}

diff --git a/surfsense_web/contracts/enums/toolIcons.tsx b/surfsense_web/contracts/enums/toolIcons.tsx index 494c0eaee..496c26577 100644 --- a/surfsense_web/contracts/enums/toolIcons.tsx +++ b/surfsense_web/contracts/enums/toolIcons.tsx @@ -1,6 +1,7 @@ import { Brain, Calendar, + Clock3, FileEdit, FilePlus, FileText, @@ -24,7 +25,6 @@ import { SearchCheck, Send, Trash2, - Workflow, Wrench, } from "lucide-react"; @@ -47,7 +47,7 @@ const TOOL_ICONS: Record = { scrape_webpage: ScanLine, web_search: Globe, // Automations - create_automation: Workflow, + create_automation: Clock3, // Memory update_memory: Brain, // Filesystem (built-in deepagent + middleware) From ced1bb85edc778e710ff133fe5265df089a40a11 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 09:43:56 +0530 Subject: [PATCH 29/59] feat(model-connections): implement bulk model update endpoint and related schema changes --- .../app/routes/model_connections_routes.py | 31 +- surfsense_backend/app/schemas/__init__.py | 1 + .../app/schemas/model_connections.py | 6 + .../model-connections-mutation.atoms.ts | 12 + .../settings/model-connections-settings.tsx | 626 +++++++++++++----- .../types/model-connections.types.ts | 7 + .../lib/apis/model-connections-api.service.ts | 23 +- 7 files changed, 538 insertions(+), 168 deletions(-) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index ecb86711e..730c68565 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -1,7 +1,7 @@ import logging from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import select +from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload @@ -25,6 +25,7 @@ from app.schemas import ( ModelRead, ModelRolesRead, ModelRolesUpdate, + ModelsBulkUpdate, ModelUpdate, VerifyConnectionResponse, ) @@ -62,6 +63,7 @@ def _connection_read(conn: Connection | dict, models: list[Model | dict] | None id=conn.id, provider=conn.provider, base_url=conn.base_url, + api_key=conn.api_key, extra=conn.extra or {}, scope=conn.scope, search_space_id=conn.search_space_id, @@ -351,6 +353,33 @@ async def add_manual_model( return _model_read(model) +@router.patch("/model-connections/{connection_id}/models", response_model=list[ModelRead]) +async def bulk_update_models( + connection_id: int, + data: ModelsBulkUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + conn = await _load_connection(session, connection_id) + await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value) + + model_ids = set(data.model_ids) + await session.execute( + update(Model) + .where(Model.connection_id == connection_id, Model.id.in_(model_ids)) + .values(enabled=data.enabled) + ) + await session.commit() + session.expire_all() + + result = await session.execute( + select(Model) + .where(Model.connection_id == connection_id, Model.id.in_(model_ids)) + .order_by(Model.id) + ) + return [_model_read(model) for model in result.scalars().all()] + + @router.put("/models/{model_id}", response_model=ModelRead) async def update_model( model_id: int, diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 8ac7c5bbb..55e712f12 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -53,6 +53,7 @@ from .model_connections import ( ModelRead, ModelRolesRead, ModelRolesUpdate, + ModelsBulkUpdate, ModelUpdate, VerifyConnectionResponse, ) diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py index c081a193d..0b03c7fab 100644 --- a/surfsense_backend/app/schemas/model_connections.py +++ b/surfsense_backend/app/schemas/model_connections.py @@ -32,6 +32,7 @@ class ConnectionRead(BaseModel): id: int provider: str base_url: str | None = None + api_key: str | None = None extra: dict[str, Any] = Field(default_factory=dict) scope: ConnectionScope | str search_space_id: int | None = None @@ -87,6 +88,11 @@ class ModelUpdate(BaseModel): capabilities_override: dict[str, Any] | None = None +class ModelsBulkUpdate(BaseModel): + model_ids: list[int] = Field(..., min_length=1, max_length=1000) + enabled: bool + + class ModelProviderRead(BaseModel): provider: str transport: str diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts index 101bad1b5..fee3b95ba 100644 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -6,6 +6,7 @@ import type { ModelCreateRequest, ModelRead, ModelRoles, + ModelsBulkUpdateRequest, ModelUpdateRequest, VerifyConnectionResponse, } from "@/contracts/types/model-connections.types"; @@ -127,6 +128,17 @@ export const updateModelMutationAtom = atomWithMutation((get) => { }; }); +export const bulkUpdateModelsMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["models", "bulk-update"], + mutationFn: ({ connectionId, data }: { connectionId: number; data: ModelsBulkUpdateRequest }) => + modelConnectionsApiService.bulkUpdateModels(connectionId, data), + onSuccess: () => invalidateModelConnections(searchSpaceId), + onError: (error: Error) => toast.error(error.message || "Failed to update models"), + }; +}); + export const testModelMutationAtom = atomWithMutation((get) => { const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); return { diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 0e541548b..9112cfe64 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -1,14 +1,24 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { CheckCircle2, PlugZap, Plus, RefreshCcw, Trash2, XCircle } from "lucide-react"; +import { + Check, + CheckCircle2, + ChevronsUpDown, + Eye, + EyeOff, + RefreshCcw, + Settings, + Trash2, + XCircle, +} from "lucide-react"; import { useState } from "react"; import { addManualModelMutationAtom, + bulkUpdateModelsMutationAtom, createModelConnectionMutationAtom, deleteModelConnectionMutationAtom, discoverConnectionModelsMutationAtom, - testModelMutationAtom, updateModelConnectionMutationAtom, updateModelMutationAtom, updateModelRolesMutationAtom, @@ -20,11 +30,41 @@ import { modelProvidersAtom, modelRolesAtom, } from "@/atoms/model-connections/model-connections-query.atoms"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger, +} from "@/components/ui/alert-dialog"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; +import { Checkbox } from "@/components/ui/checkbox"; +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, +} from "@/components/ui/command"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger, +} from "@/components/ui/dialog"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Select, SelectContent, @@ -32,8 +72,14 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; -import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; +import { Separator } from "@/components/ui/separator"; +import type { + ConnectionRead, + ConnectionUpdateRequest, + ModelRead, +} from "@/contracts/types/model-connections.types"; import { getProviderIcon } from "@/lib/provider-icons"; +import { cn } from "@/lib/utils"; // Free-text URL hints (datalist), mirroring OpenWebUI. These never restrict // what the user can type — any OpenAI-compatible endpoint works. @@ -69,6 +115,67 @@ const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: string }[] { key: "image_gen", label: "Image" }, ]; +function UrlSuggestionCombobox({ + value, + onChange, + placeholder, +}: { + value: string; + onChange: (value: string) => void; + placeholder: string; +}) { + const [open, setOpen] = useState(false); + + return ( + + + + + + + + + + Use the custom URL you typed + + + {URL_SUGGESTIONS.map((url) => ( + { + onChange(url); + setOpen(false); + }} + > + + {url} + + ))} + + + + + + ); +} + function StatusBadge({ connection }: { connection: ConnectionRead }) { if (connection.last_status === "OK") { return ( @@ -105,11 +212,15 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom); const addManualModel = useAtomValue(addManualModelMutationAtom); const updateModel = useAtomValue(updateModelMutationAtom); - const testModel = useAtomValue(testModelMutationAtom); + const bulkUpdateModels = useAtomValue(bulkUpdateModelsMutationAtom); const allowlist = Array.isArray(connection.extra?.model_ids) ? (connection.extra.model_ids as string[]) : []; + const [isSettingsOpen, setIsSettingsOpen] = useState(false); + const [baseUrlDraft, setBaseUrlDraft] = useState(connection.base_url ?? ""); + const [apiKeyDraft, setApiKeyDraft] = useState(""); + const [showApiKey, setShowApiKey] = useState(false); const [allowlistText, setAllowlistText] = useState(allowlist.join(", ")); const [manualModelId, setManualModelId] = useState(""); const [modelFilter, setModelFilter] = useState(null); @@ -122,6 +233,38 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { const filteredModels = modelFilter ? connection.models.filter((model) => capability(model, modelFilter)) : connection.models; + const allFilteredModelsEnabled = + filteredModels.length > 0 && filteredModels.every((model) => model.enabled); + const hasConnectionChanges = + baseUrlDraft.trim() !== (connection.base_url ?? "") || + apiKeyDraft.trim() !== (connection.api_key ?? ""); + + function handleSettingsOpenChange(open: boolean) { + setIsSettingsOpen(open); + if (open) { + setBaseUrlDraft(connection.base_url ?? ""); + setApiKeyDraft(connection.api_key ?? ""); + setShowApiKey(false); + setAllowlistText(allowlist.join(", ")); + } + } + + function saveConnectionSettings() { + const data: ConnectionUpdateRequest = { + base_url: baseUrlDraft.trim() || null, + }; + + if (apiKeyDraft.trim() !== (connection.api_key ?? "")) { + data.api_key = apiKeyDraft.trim() || null; + } + + updateConnection.mutate( + { id: connection.id, data }, + { + onSuccess: () => setApiKeyDraft(""), + } + ); + } function saveAllowlist() { const ids = allowlistText @@ -144,170 +287,321 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { } function deleteCurrentConnection() { - const confirmed = window.confirm( - `Delete the ${providerLabel} connection and all of its models? This cannot be undone.` - ); - if (!confirmed) return; deleteConnection.mutate(connection.id); } + function toggleFilteredModels() { + const nextEnabled = !allFilteredModelsEnabled; + const modelIds = filteredModels + .filter((model) => model.enabled !== nextEnabled) + .map((model) => model.id); + + if (modelIds.length === 0) return; + + bulkUpdateModels.mutate({ + connectionId: connection.id, + data: { model_ids: modelIds, enabled: nextEnabled }, + }); + } + return ( -
-
-
-
+
+
+
+
{getProviderIcon(providerLabel, { className: "size-4" })} - {providerLabel} + {providerLabel} + {connection.scope === "GLOBAL" ? ( + + Default + + ) : null}
-
+
{connection.base_url || "Provider default endpoint"}
-
+
- - - -
-
- - {connection.last_status && connection.last_status !== "OK" ? ( -

- {connection.last_error || "Could not list models."} Chat may still work — add model IDs - manually below. -

- ) : null} - - {!isLocal ? ( -
- -
- setAllowlistText(event.target.value)} - placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro" - /> - -
-

- Leave empty to discover all models. Recommended for providers with large catalogs (e.g. - OpenRouter). -

-
- ) : null} - -
- setManualModelId(event.target.value)} - onKeyDown={(event) => { - if (event.key === "Enter") { - event.preventDefault(); - addModel(); - } - }} - placeholder="Add a model ID manually (for providers without /models)" - /> - -
- - {connection.models.length > 0 ? ( -
- Filter models - {MODEL_CAPABILITY_FILTERS.map((filter) => { - const count = connection.models.filter((model) => capability(model, filter.key)).length; - const isActive = modelFilter === filter.key; - - return ( - - ); - })} -
- ) : null} + + + +
+ {getProviderIcon(providerLabel, { className: "size-5" })} +
+ + Configure {providerLabel} + + + Manage credentials and choose which models are available from this provider. + +
+
+
-
- {filteredModels.length === 0 && modelFilter ? ( -
- No {MODEL_CAPABILITY_FILTERS.find((filter) => filter.key === modelFilter)?.label.toLowerCase()}{" "} - models found on this connection. -
- ) : null} - {filteredModels.map((model) => ( -
-
-
- {getProviderIcon(providerLabel, { className: "size-4" })} - {modelLabel(model)} - {model.source === "MANUAL" ? ( - - manual - - ) : null} +
+
+
+ + +

+ Leave empty to use the provider default endpoint. +

+
+ +
+ +
+ setApiKeyDraft(event.target.value)} + placeholder={connection.has_api_key ? "Saved API key" : "Paste an API key"} + type={showApiKey ? "text" : "password"} + className="pr-11" + /> + +
+
+ + {!isLocal ? ( +
+ +
+ setAllowlistText(event.target.value)} + placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro" + /> + +
+

+ Leave empty to discover all models. Recommended for providers with large + catalogs. +

+
+ ) : null} + + + +
+
+
+
Models
+

+ Select models to make available for this provider. +

+
+
+ + +
+
+ +
+ setManualModelId(event.target.value)} + onKeyDown={(event) => { + if (event.key === "Enter") { + event.preventDefault(); + addModel(); + } + }} + placeholder="Add a model ID manually" + /> + +
+ + {connection.models.length > 0 ? ( +
+ + Filter models + + {MODEL_CAPABILITY_FILTERS.map((filter) => { + const count = connection.models.filter((model) => + capability(model, filter.key) + ).length; + const isActive = modelFilter === filter.key; + + return ( + + ); + })} +
+ ) : null} + +
+ {connection.models.length === 0 ? ( +
+ No models yet. Use the refresh button to discover models or add one + manually. +
+ ) : null} + {filteredModels.length === 0 && modelFilter ? ( +
+ No{" "} + {MODEL_CAPABILITY_FILTERS.find( + (filter) => filter.key === modelFilter + )?.label.toLowerCase()}{" "} + models found on this connection. +
+ ) : null} +
+ {filteredModels.map((model) => ( +
+ + updateModel.mutate({ + id: model.id, + data: { enabled: checked === true }, + }) + } + disabled={updateModel.isPending} + /> +
+
+ {modelLabel(model)} + {model.source === "MANUAL" ? ( + + manual + + ) : null} +
+
+ {["chat", "vision", "image_gen"] + .filter((key) => + capability(model, key as "chat" | "vision" | "image_gen") + ) + .join(", ") || "No discovered capabilities"} +
+
+
+ ))} +
+
+
+ + {connection.last_status && connection.last_status !== "OK" ? ( +

+ {connection.last_error || "Could not list models."} Chat may still work; add + model IDs manually if discovery is unavailable. +

+ ) : null} +
-
- {["chat", "vision", "image_gen"] - .filter((key) => capability(model, key as "chat" | "vision" | "image_gen")) - .join(", ") || "No discovered capabilities"} -
-
-
- + + + + + + + + + -
-
- ))} + + + + Delete this provider? + + {providerLabel} and all of + its models will be removed from this search space. This cannot be undone. + + + + Cancel + + Delete + + + + +
); @@ -394,19 +688,13 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
- setBaseUrl(event.target.value)} + onChange={setBaseUrl} placeholder={ isOllama ? "http://host.docker.internal:11434" : "https://api.example.com/v1" } - list="model-conn-url-suggestions" /> - - {URL_SUGGESTIONS.map((url) => ( -
@@ -425,7 +713,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num Boolean(selectedProvider?.base_url_required && !baseUrl.trim()) } > - Add + Add
@@ -439,11 +727,17 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num

-
- {connections.map((connection) => ( - - ))} -
+ {connections.length > 0 ? ( +
+ +

Available Providers

+
+ {connections.map((connection) => ( + + ))} +
+
+ ) : null} diff --git a/surfsense_web/contracts/types/model-connections.types.ts b/surfsense_web/contracts/types/model-connections.types.ts index a34687d74..c75f4c90a 100644 --- a/surfsense_web/contracts/types/model-connections.types.ts +++ b/surfsense_web/contracts/types/model-connections.types.ts @@ -26,6 +26,7 @@ export const connectionRead = z.object({ id: z.number(), provider: z.string(), base_url: z.string().nullable().optional(), + api_key: z.string().nullable().optional(), extra: z.record(z.string(), z.any()).default({}), scope: z.union([connectionScopeEnum, z.string()]), search_space_id: z.number().nullable().optional(), @@ -73,6 +74,11 @@ export const modelUpdateRequest = z.object({ capabilities_override: z.record(z.string(), z.any()).optional(), }); +export const modelsBulkUpdateRequest = z.object({ + model_ids: z.array(z.number()).min(1).max(1000), + enabled: z.boolean(), +}); + export const verifyConnectionResponse = z.object({ status: z.string(), ok: z.boolean(), @@ -107,6 +113,7 @@ export type ConnectionCreateRequest = z.infer; export type ConnectionUpdateRequest = z.infer; export type ModelCreateRequest = z.infer; export type ModelUpdateRequest = z.infer; +export type ModelsBulkUpdateRequest = z.infer; export type ModelRoles = z.infer; export type VerifyConnectionResponse = z.infer; export type ModelProviderRead = z.infer; diff --git a/surfsense_web/lib/apis/model-connections-api.service.ts b/surfsense_web/lib/apis/model-connections-api.service.ts index 12ad8e0d2..bd5aa1309 100644 --- a/surfsense_web/lib/apis/model-connections-api.service.ts +++ b/surfsense_web/lib/apis/model-connections-api.service.ts @@ -10,12 +10,14 @@ import { type ModelProviderRead, type ModelRead, type ModelRoles, + type ModelsBulkUpdateRequest, type ModelUpdateRequest, modelCreateRequest, - modelProviderListResponse, modelListResponse, + modelProviderListResponse, modelRead, modelRoles, + modelsBulkUpdateRequest, modelUpdateRequest, type VerifyConnectionResponse, verifyConnectionResponse, @@ -97,6 +99,25 @@ class ModelConnectionsApiService { }); }; + bulkUpdateModels = async ( + connectionId: number, + request: ModelsBulkUpdateRequest + ): Promise => { + const parsed = modelsBulkUpdateRequest.safeParse(request); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); + } + return baseApiService.request( + `/api/v1/model-connections/${connectionId}/models`, + modelListResponse, + { + method: "PATCH", + headers: { "Content-Type": "application/json" }, + body: parsed.data, + } + ); + }; + testModel = async (id: number): Promise => { return baseApiService.post(`/api/v1/models/${id}/test`, verifyConnectionResponse); }; From 356f0e56c5d5b94843436328b0ca5329299800a4 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 22:04:44 +0530 Subject: [PATCH 30/59] feat(model-connections): add provider-specific connection forms and shared components --- .../search-space-settings/layout-shell.tsx | 4 +- .../free-chat/free-model-selector.tsx | 4 +- .../components/new-chat/model-selector.tsx | 4 +- .../settings/model-connections-settings.tsx | 762 +++++------------- .../model-connections/azure-connect-form.tsx | 59 ++ .../bedrock-connect-form.tsx | 134 +++ .../model-connections/connect-fields.tsx | 83 ++ .../connection-settings-dialog.tsx | 249 ++++++ .../default-connect-form.tsx | 51 ++ .../settings/model-connections/model-utils.ts | 25 + .../models-selection-panel.tsx | 196 +++++ .../provider-connect-dialog.tsx | 154 ++++ .../model-connections/provider-metadata.tsx | 139 ++++ .../model-connections/vertex-connect-form.tsx | 127 +++ surfsense_web/lib/provider-icons.tsx | 6 +- 15 files changed, 1423 insertions(+), 574 deletions(-) create mode 100644 surfsense_web/components/settings/model-connections/azure-connect-form.tsx create mode 100644 surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx create mode 100644 surfsense_web/components/settings/model-connections/connect-fields.tsx create mode 100644 surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx create mode 100644 surfsense_web/components/settings/model-connections/default-connect-form.tsx create mode 100644 surfsense_web/components/settings/model-connections/model-utils.ts create mode 100644 surfsense_web/components/settings/model-connections/models-selection-panel.tsx create mode 100644 surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx create mode 100644 surfsense_web/components/settings/model-connections/provider-metadata.tsx create mode 100644 surfsense_web/components/settings/model-connections/vertex-connect-form.tsx diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx index 9d9045004..d30ea8a3a 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/layout-shell.tsx @@ -2,8 +2,8 @@ import { BookText, - Bot, CircleUser, + Cpu, Earth, UserKey, } from "lucide-react"; @@ -54,7 +54,7 @@ export function SearchSpaceSettingsLayoutShell({ { value: "models" as const, label: t("nav_models"), - icon: , + icon: , }, { value: "team-roles" as const, diff --git a/surfsense_web/components/free-chat/free-model-selector.tsx b/surfsense_web/components/free-chat/free-model-selector.tsx index 9bf4ecee5..d04bca8a2 100644 --- a/surfsense_web/components/free-chat/free-model-selector.tsx +++ b/surfsense_web/components/free-chat/free-model-selector.tsx @@ -1,6 +1,6 @@ "use client"; -import { Bot, Check, ChevronDown } from "lucide-react"; +import { Check, ChevronDown, Cpu } from "lucide-react"; import { useRouter } from "next/navigation"; import { useCallback, useEffect, useMemo, useState } from "react"; import { Badge } from "@/components/ui/badge"; @@ -82,7 +82,7 @@ export function FreeModelSelector({ className }: { className?: string }) { ) : ( <> - + Select Model )} diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 6850096d6..dd1bcb431 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -1,7 +1,7 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { Bot, Check, ChevronDown, ImageOff, Search, Settings2, Zap } from "lucide-react"; +import { Check, ChevronDown, Cpu, ImageOff, Search, Settings2, Zap } from "lucide-react"; import { useMemo, useState } from "react"; import { updateModelRolesMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { @@ -222,7 +222,7 @@ export function ModelSelector({ {selected ? ( getProviderIcon(selected.provider, { className: "size-4" }) ) : ( - + )} {selected ? modelName(selected) : "Auto"} diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 9112cfe64..2e15ce2e9 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -1,17 +1,7 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { - Check, - CheckCircle2, - ChevronsUpDown, - Eye, - EyeOff, - RefreshCcw, - Settings, - Trash2, - XCircle, -} from "lucide-react"; +import { CheckCircle2, Trash2, XCircle } from "lucide-react"; import { useState } from "react"; import { addManualModelMutationAtom, @@ -19,10 +9,8 @@ import { createModelConnectionMutationAtom, deleteModelConnectionMutationAtom, discoverConnectionModelsMutationAtom, - updateModelConnectionMutationAtom, updateModelMutationAtom, updateModelRolesMutationAtom, - verifyModelConnectionMutationAtom, } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { globalModelConnectionsAtom, @@ -44,27 +32,7 @@ import { import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; -import { Checkbox } from "@/components/ui/checkbox"; -import { - Command, - CommandEmpty, - CommandGroup, - CommandInput, - CommandItem, - CommandList, -} from "@/components/ui/command"; -import { - Dialog, - DialogContent, - DialogDescription, - DialogFooter, - DialogHeader, - DialogTitle, - DialogTrigger, -} from "@/components/ui/dialog"; -import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; -import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Select, SelectContent, @@ -73,108 +41,16 @@ import { SelectValue, } from "@/components/ui/select"; import { Separator } from "@/components/ui/separator"; -import type { - ConnectionRead, - ConnectionUpdateRequest, - ModelRead, -} from "@/contracts/types/model-connections.types"; -import { getProviderIcon } from "@/lib/provider-icons"; -import { cn } from "@/lib/utils"; - -// Free-text URL hints (datalist), mirroring OpenWebUI. These never restrict -// what the user can type — any OpenAI-compatible endpoint works. -const URL_SUGGESTIONS = [ - "https://api.openai.com/v1", - "https://api.anthropic.com/v1", - "https://openrouter.ai/api/v1", - "https://generativelanguage.googleapis.com/v1beta/openai", - "https://api.groq.com/openai/v1", - "https://api.mistral.ai/v1", - "https://api.deepseek.com/v1", - "https://api.x.ai/v1", - "http://host.docker.internal:11434", - "http://host.docker.internal:1234/v1", - "http://host.docker.internal:8000/v1", -]; - -function modelLabel(model: ModelRead) { - return model.display_name || model.model_id; -} - -function capability(model: ModelRead, key: "chat" | "vision" | "image_gen") { - if (key === "chat") return Boolean(model.supports_chat); - if (key === "vision") return Boolean(model.supports_image_input); - return Boolean(model.supports_image_generation); -} - -type ModelCapabilityFilter = "chat" | "vision" | "image_gen"; - -const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: string }[] = [ - { key: "chat", label: "Chat" }, - { key: "vision", label: "Vision" }, - { key: "image_gen", label: "Image" }, -]; - -function UrlSuggestionCombobox({ - value, - onChange, - placeholder, -}: { - value: string; - onChange: (value: string) => void; - placeholder: string; -}) { - const [open, setOpen] = useState(false); - - return ( - - - - - - - - - - Use the custom URL you typed - - - {URL_SUGGESTIONS.map((url) => ( - { - onChange(url); - setOpen(false); - }} - > - - {url} - - ))} - - - - - - ); -} +import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; +import { ConnectionSettingsDialog } from "./model-connections/connection-settings-dialog"; +import { capability, modelLabel } from "./model-connections/model-utils"; +import { ProviderConnectDialog } from "./model-connections/provider-connect-dialog"; +import { + type ConnectionDraft, + PROVIDER_ORDER, + providerDisplay, + providerIcon, +} from "./model-connections/provider-metadata"; function StatusBadge({ connection }: { connection: ConnectionRead }) { if (connection.last_status === "OK") { @@ -198,7 +74,7 @@ function flattenModels(connections: ConnectionRead[]) { return connections.flatMap((connection) => connection.models.map((model) => ({ ...model, - connectionName: connection.provider, + connectionName: providerDisplay(connection.provider).name, connectionId: connection.id, provider: connection.provider, })) @@ -206,110 +82,21 @@ function flattenModels(connections: ConnectionRead[]) { } function ConnectionCard({ connection }: { connection: ConnectionRead }) { - const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom); - const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); - const updateConnection = useAtomValue(updateModelConnectionMutationAtom); const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom); - const addManualModel = useAtomValue(addManualModelMutationAtom); - const updateModel = useAtomValue(updateModelMutationAtom); - const bulkUpdateModels = useAtomValue(bulkUpdateModelsMutationAtom); - const allowlist = Array.isArray(connection.extra?.model_ids) - ? (connection.extra.model_ids as string[]) - : []; - const [isSettingsOpen, setIsSettingsOpen] = useState(false); - const [baseUrlDraft, setBaseUrlDraft] = useState(connection.base_url ?? ""); - const [apiKeyDraft, setApiKeyDraft] = useState(""); - const [showApiKey, setShowApiKey] = useState(false); - const [allowlistText, setAllowlistText] = useState(allowlist.join(", ")); - const [manualModelId, setManualModelId] = useState(""); - const [modelFilter, setModelFilter] = useState(null); - - const providerLabel = connection.provider; - const isLocal = - connection.provider === "ollama_chat" || - connection.provider === "lm_studio" || - !connection.base_url?.startsWith("https"); - const filteredModels = modelFilter - ? connection.models.filter((model) => capability(model, modelFilter)) - : connection.models; - const allFilteredModelsEnabled = - filteredModels.length > 0 && filteredModels.every((model) => model.enabled); - const hasConnectionChanges = - baseUrlDraft.trim() !== (connection.base_url ?? "") || - apiKeyDraft.trim() !== (connection.api_key ?? ""); - - function handleSettingsOpenChange(open: boolean) { - setIsSettingsOpen(open); - if (open) { - setBaseUrlDraft(connection.base_url ?? ""); - setApiKeyDraft(connection.api_key ?? ""); - setShowApiKey(false); - setAllowlistText(allowlist.join(", ")); - } - } - - function saveConnectionSettings() { - const data: ConnectionUpdateRequest = { - base_url: baseUrlDraft.trim() || null, - }; - - if (apiKeyDraft.trim() !== (connection.api_key ?? "")) { - data.api_key = apiKeyDraft.trim() || null; - } - - updateConnection.mutate( - { id: connection.id, data }, - { - onSuccess: () => setApiKeyDraft(""), - } - ); - } - - function saveAllowlist() { - const ids = allowlistText - .split(",") - .map((value) => value.trim()) - .filter(Boolean); - updateConnection.mutate({ - id: connection.id, - data: { extra: { ...(connection.extra ?? {}), model_ids: ids } }, - }); - } - - function addModel() { - const modelId = manualModelId.trim(); - if (!modelId) return; - addManualModel.mutate( - { connectionId: connection.id, data: { model_id: modelId } }, - { onSuccess: () => setManualModelId("") } - ); - } + const providerMeta = providerDisplay(connection.provider); + const providerLabel = providerMeta.name; function deleteCurrentConnection() { deleteConnection.mutate(connection.id); } - function toggleFilteredModels() { - const nextEnabled = !allFilteredModelsEnabled; - const modelIds = filteredModels - .filter((model) => model.enabled !== nextEnabled) - .map((model) => model.id); - - if (modelIds.length === 0) return; - - bulkUpdateModels.mutate({ - connectionId: connection.id, - data: { model_ids: modelIds, enabled: nextEnabled }, - }); - } - return ( -
-
+
+
- {getProviderIcon(providerLabel, { className: "size-4" })} + {providerIcon(connection.provider)} {providerLabel} {connection.scope === "GLOBAL" ? ( @@ -323,253 +110,7 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
- - - - - - -
- {getProviderIcon(providerLabel, { className: "size-5" })} -
- - Configure {providerLabel} - - - Manage credentials and choose which models are available from this provider. - -
-
-
- -
-
-
- - -

- Leave empty to use the provider default endpoint. -

-
- -
- -
- setApiKeyDraft(event.target.value)} - placeholder={connection.has_api_key ? "Saved API key" : "Paste an API key"} - type={showApiKey ? "text" : "password"} - className="pr-11" - /> - -
-
- - {!isLocal ? ( -
- -
- setAllowlistText(event.target.value)} - placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro" - /> - -
-

- Leave empty to discover all models. Recommended for providers with large - catalogs. -

-
- ) : null} - - - -
-
-
-
Models
-

- Select models to make available for this provider. -

-
-
- - -
-
- -
- setManualModelId(event.target.value)} - onKeyDown={(event) => { - if (event.key === "Enter") { - event.preventDefault(); - addModel(); - } - }} - placeholder="Add a model ID manually" - /> - -
- - {connection.models.length > 0 ? ( -
- - Filter models - - {MODEL_CAPABILITY_FILTERS.map((filter) => { - const count = connection.models.filter((model) => - capability(model, filter.key) - ).length; - const isActive = modelFilter === filter.key; - - return ( - - ); - })} -
- ) : null} - -
- {connection.models.length === 0 ? ( -
- No models yet. Use the refresh button to discover models or add one - manually. -
- ) : null} - {filteredModels.length === 0 && modelFilter ? ( -
- No{" "} - {MODEL_CAPABILITY_FILTERS.find( - (filter) => filter.key === modelFilter - )?.label.toLowerCase()}{" "} - models found on this connection. -
- ) : null} -
- {filteredModels.map((model) => ( -
- - updateModel.mutate({ - id: model.id, - data: { enabled: checked === true }, - }) - } - disabled={updateModel.isPending} - /> -
-
- {modelLabel(model)} - {model.source === "MANUAL" ? ( - - manual - - ) : null} -
-
- {["chat", "vision", "image_gen"] - .filter((key) => - capability(model, key as "chat" | "vision" | "image_gen") - ) - .join(", ") || "No discovered capabilities"} -
-
-
- ))} -
-
-
- - {connection.last_status && connection.last_status !== "OK" ? ( -

- {connection.last_error || "Could not list models."} Chat may still work; add - model IDs manually if discovery is unavailable. -

- ) : null} -
-
- - - - - -
-
+ -
-
-
+
+
+
+

Add Provider

- {selectedProvider - ? `${selectedProvider.transport} transport, ${selectedProvider.discovery} discovery.` - : "Choose a provider preset."}{" "} - Base URL is explicit and editable. Local URLs are tested from the backend container, - so use host.docker.internal instead of localhost. + SurfSense supports popular providers and self-hosted model endpoints.

+
+ {sortedProviders.map((item) => { + const meta = providerDisplay(item.provider); - {connections.length > 0 ? ( + return ( + + ); + })} +
+
+ + + + {connections.length > 0 ? ( +
+ +

Available Providers

- -

Available Providers

-
- {connections.map((connection) => ( - - ))} -
+ {connections.map((connection) => ( + + ))}
- ) : null} - - +
+ ) : null} +
diff --git a/surfsense_web/components/settings/model-connections/azure-connect-form.tsx b/surfsense_web/components/settings/model-connections/azure-connect-form.tsx new file mode 100644 index 000000000..11a2e25d3 --- /dev/null +++ b/surfsense_web/components/settings/model-connections/azure-connect-form.tsx @@ -0,0 +1,59 @@ +import { useState } from "react"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { ApiKeyField, ConnectFormFooter } from "./connect-fields"; +import { + isValidAzureTargetUri, + type ProviderConnectFormProps, + parseAzureTargetUri, +} from "./provider-metadata"; + +/** + * Azure OpenAI connect form. The user pastes a single Target URI, which we parse + * into api base, api version, and the deployment name (seeded as the model). + */ +export function AzureConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) { + const [targetUri, setTargetUri] = useState(""); + const [apiKey, setApiKey] = useState(""); + const canSubmit = isValidAzureTargetUri(targetUri) && Boolean(apiKey.trim()); + + function handleSubmit() { + const parsed = parseAzureTargetUri(targetUri); + onSubmit({ + base_url: parsed?.origin ?? null, + api_key: apiKey || null, + extra: parsed?.apiVersion ? { api_version: parsed.apiVersion } : {}, + seedModelId: parsed?.deploymentName || undefined, + }); + } + + return ( + <> +
+
+ + setTargetUri(event.target.value)} + placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview" + /> +

+ Paste your endpoint target URI from Azure OpenAI (including API base, deployment name, + and API version). +

+
+ +
+ + + ); +} diff --git a/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx b/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx new file mode 100644 index 000000000..9466c7cd1 --- /dev/null +++ b/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx @@ -0,0 +1,134 @@ +import { useState } from "react"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { ConnectFormFooter } from "./connect-fields"; +import { + AWS_REGION_OPTIONS, + BEDROCK_AUTH_ACCESS_KEY, + BEDROCK_AUTH_IAM, + BEDROCK_AUTH_LONG_TERM_API_KEY, + type ProviderConnectFormProps, +} from "./provider-metadata"; + +/** + * Amazon Bedrock connect form. Region + auth method drive which AWS credentials + * are collected; everything rides along in `extra.litellm_params`. + */ +export function BedrockConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) { + const [region, setRegion] = useState(""); + const [authMethod, setAuthMethod] = useState(BEDROCK_AUTH_ACCESS_KEY); + const [accessKeyId, setAccessKeyId] = useState(""); + const [secretAccessKey, setSecretAccessKey] = useState(""); + const [bearerToken, setBearerToken] = useState(""); + + const canSubmit = (() => { + if (!region) return false; + if (authMethod === BEDROCK_AUTH_ACCESS_KEY) { + return Boolean(accessKeyId && secretAccessKey); + } + if (authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY) { + return Boolean(bearerToken); + } + return true; + })(); + + function handleSubmit() { + const params: Record = { aws_region_name: region }; + if (authMethod === BEDROCK_AUTH_ACCESS_KEY) { + params.aws_access_key_id = accessKeyId; + params.aws_secret_access_key = secretAccessKey; + } else if (authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY) { + params.aws_bearer_token_bedrock = bearerToken; + } + onSubmit({ base_url: null, api_key: null, extra: { litellm_params: params } }); + } + + return ( + <> +
+
+ + +
+
+ + +
+ {authMethod === BEDROCK_AUTH_ACCESS_KEY ? ( + <> +
+ + setAccessKeyId(event.target.value)} + placeholder="AKIAIOSFODNN7EXAMPLE" + /> +
+
+ + setSecretAccessKey(event.target.value)} + placeholder="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + type="password" + /> +
+ + ) : null} + {authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY ? ( +
+ + setBearerToken(event.target.value)} + placeholder="Your long-term API key" + type="password" + /> +
+ ) : null} + {authMethod === BEDROCK_AUTH_IAM ? ( +

+ SurfSense will use the IAM role attached to the environment it's running in to + authenticate. +

+ ) : null} +

+ Add Bedrock model IDs from the provider's settings after connecting. +

+
+ + + ); +} diff --git a/surfsense_web/components/settings/model-connections/connect-fields.tsx b/surfsense_web/components/settings/model-connections/connect-fields.tsx new file mode 100644 index 000000000..af8db7f12 --- /dev/null +++ b/surfsense_web/components/settings/model-connections/connect-fields.tsx @@ -0,0 +1,83 @@ +import { Button } from "@/components/ui/button"; +import { DialogFooter } from "@/components/ui/dialog"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; + +interface ApiBaseUrlFieldProps { + value: string; + onChange: (value: string) => void; + optional?: boolean; + /** Placeholder, typically the provider's prefilled default base URL. */ + placeholder?: string; +} + +/** Shared API Base URL input. The prefilled default is passed in via `value`. */ +export function ApiBaseUrlField({ value, onChange, optional, placeholder }: ApiBaseUrlFieldProps) { + return ( +
+ + onChange(event.target.value)} + placeholder={placeholder || "https://api.example.com/v1"} + /> +

+ Local URLs are tested from the backend container, so use host.docker.internal instead of + localhost. +

+
+ ); +} + +interface ApiKeyFieldProps { + value: string; + onChange: (value: string) => void; + label?: string; + placeholder?: string; +} + +/** Shared masked API Key input. */ +export function ApiKeyField({ + value, + onChange, + label = "API Key", + placeholder = "API key", +}: ApiKeyFieldProps) { + return ( +
+ + onChange(event.target.value)} + placeholder={placeholder} + type="password" + /> +
+ ); +} + +interface ConnectFormFooterProps { + onCancel: () => void; + onSubmit: () => void; + canSubmit: boolean; + isPending: boolean; +} + +/** Shared Cancel / Connect footer for every provider connect form. */ +export function ConnectFormFooter({ + onCancel, + onSubmit, + canSubmit, + isPending, +}: ConnectFormFooterProps) { + return ( + + + + + ); +} diff --git a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx new file mode 100644 index 000000000..f3821af46 --- /dev/null +++ b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx @@ -0,0 +1,249 @@ +import { useAtomValue } from "jotai"; +import { Eye, EyeOff, Settings } from "lucide-react"; +import { useState } from "react"; +import { + addManualModelMutationAtom, + bulkUpdateModelsMutationAtom, + discoverConnectionModelsMutationAtom, + updateModelConnectionMutationAtom, + updateModelMutationAtom, + verifyModelConnectionMutationAtom, +} from "@/atoms/model-connections/model-connections-mutation.atoms"; +import { Button } from "@/components/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger, +} from "@/components/ui/dialog"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Separator } from "@/components/ui/separator"; +import type { + ConnectionRead, + ConnectionUpdateRequest, + ModelRead, +} from "@/contracts/types/model-connections.types"; +import { ModelsSelectionPanel } from "./models-selection-panel"; +import { providerIcon } from "./provider-metadata"; + +interface ConnectionSettingsDialogProps { + connection: ConnectionRead; + providerLabel: string; +} + +export function ConnectionSettingsDialog({ + connection, + providerLabel, +}: ConnectionSettingsDialogProps) { + const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom); + const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); + const updateConnection = useAtomValue(updateModelConnectionMutationAtom); + const addManualModel = useAtomValue(addManualModelMutationAtom); + const updateModel = useAtomValue(updateModelMutationAtom); + const bulkUpdateModels = useAtomValue(bulkUpdateModelsMutationAtom); + + const allowlist = Array.isArray(connection.extra?.model_ids) + ? (connection.extra.model_ids as string[]) + : []; + const [isOpen, setIsOpen] = useState(false); + const [baseUrlDraft, setBaseUrlDraft] = useState(connection.base_url ?? ""); + const [apiKeyDraft, setApiKeyDraft] = useState(""); + const [showApiKey, setShowApiKey] = useState(false); + const [allowlistText, setAllowlistText] = useState(allowlist.join(", ")); + + const isLocal = + connection.provider === "ollama_chat" || + connection.provider === "lm_studio" || + !connection.base_url?.startsWith("https"); + const hasConnectionChanges = + baseUrlDraft.trim() !== (connection.base_url ?? "") || + apiKeyDraft.trim() !== (connection.api_key ?? ""); + + function handleOpenChange(open: boolean) { + setIsOpen(open); + if (open) { + setBaseUrlDraft(connection.base_url ?? ""); + setApiKeyDraft(connection.api_key ?? ""); + setShowApiKey(false); + setAllowlistText(allowlist.join(", ")); + } + } + + function saveConnectionSettings() { + const data: ConnectionUpdateRequest = { + base_url: baseUrlDraft.trim() || null, + }; + + if (apiKeyDraft.trim() !== (connection.api_key ?? "")) { + data.api_key = apiKeyDraft.trim() || null; + } + + updateConnection.mutate( + { id: connection.id, data }, + { + onSuccess: () => setApiKeyDraft(""), + } + ); + } + + function saveAllowlist() { + const ids = allowlistText + .split(",") + .map((value) => value.trim()) + .filter(Boolean); + updateConnection.mutate({ + id: connection.id, + data: { extra: { ...(connection.extra ?? {}), model_ids: ids } }, + }); + } + + function handleToggleModel(model: ModelRead, enabled: boolean) { + updateModel.mutate({ + id: model.id, + data: { enabled }, + }); + } + + function handleBulkToggle(models: ModelRead[], enabled: boolean) { + bulkUpdateModels.mutate({ + connectionId: connection.id, + data: { model_ids: models.map((model) => model.id), enabled }, + }); + } + + return ( + + + + + + +
+ {providerIcon(connection.provider, "size-5")} +
+ + Configure {providerLabel} + + + Manage credentials and choose which models are available from this provider. + +
+
+
+ +
+
+
+ + setBaseUrlDraft(event.target.value)} + placeholder="https://api.example.com/v1" + /> +

+ Leave empty to use the provider default endpoint. +

+
+ +
+ +
+ setApiKeyDraft(event.target.value)} + placeholder={connection.has_api_key ? "Saved API key" : "Paste an API key"} + type={showApiKey ? "text" : "password"} + className="pr-11" + /> + +
+
+ + {!isLocal ? ( +
+ +
+ setAllowlistText(event.target.value)} + placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro" + /> + +
+

+ Leave empty to discover all models. Recommended for providers with large catalogs. +

+
+ ) : null} + + + + discoverModels.mutate(connection.id)} + onAddManual={(modelId) => + addManualModel.mutate({ + connectionId: connection.id, + data: { model_id: modelId }, + }) + } + onToggleModel={handleToggleModel} + onBulkToggle={handleBulkToggle} + /> + + {connection.last_status && connection.last_status !== "OK" ? ( +

+ {connection.last_error || "Could not list models."} Chat may still work; add model + IDs manually if discovery is unavailable. +

+ ) : null} +
+
+ + + + + +
+
+ ); +} diff --git a/surfsense_web/components/settings/model-connections/default-connect-form.tsx b/surfsense_web/components/settings/model-connections/default-connect-form.tsx new file mode 100644 index 000000000..3f261c6b2 --- /dev/null +++ b/surfsense_web/components/settings/model-connections/default-connect-form.tsx @@ -0,0 +1,51 @@ +import { useState } from "react"; +import { ApiBaseUrlField, ApiKeyField, ConnectFormFooter } from "./connect-fields"; +import type { ProviderConnectFormProps } from "./provider-metadata"; + +/** + * Connect form for OpenAI-compatible / native key providers (OpenAI, Anthropic, + * OpenRouter, OpenAI-Compatible, LM Studio, Ollama, …). The base URL is + * prefilled from the provider default. + */ +export function DefaultConnectForm({ + provider, + defaultBaseUrl, + baseUrlRequired, + isPending, + onCancel, + onSubmit, +}: ProviderConnectFormProps) { + const [baseUrl, setBaseUrl] = useState(defaultBaseUrl); + const [apiKey, setApiKey] = useState(""); + const isOllama = provider === "ollama_chat"; + const canSubmit = !(baseUrlRequired && !baseUrl.trim()); + + function handleSubmit() { + onSubmit({ base_url: baseUrl || null, api_key: apiKey || null, extra: {} }); + } + + return ( + <> +
+ + +
+ + + ); +} diff --git a/surfsense_web/components/settings/model-connections/model-utils.ts b/surfsense_web/components/settings/model-connections/model-utils.ts new file mode 100644 index 000000000..1db14b3eb --- /dev/null +++ b/surfsense_web/components/settings/model-connections/model-utils.ts @@ -0,0 +1,25 @@ +import type { ModelRead } from "@/contracts/types/model-connections.types"; + +export type ModelCapabilityFilter = "chat" | "vision" | "image_gen"; + +export const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: string }[] = [ + { key: "chat", label: "Chat" }, + { key: "vision", label: "Vision" }, + { key: "image_gen", label: "Image" }, +]; + +export function modelLabel(model: ModelRead) { + return model.display_name || model.model_id; +} + +export function capability(model: ModelRead, key: ModelCapabilityFilter) { + if (key === "chat") return Boolean(model.supports_chat); + if (key === "vision") return Boolean(model.supports_image_input); + return Boolean(model.supports_image_generation); +} + +export function capabilityLabels(model: ModelRead) { + return MODEL_CAPABILITY_FILTERS.filter((filter) => capability(model, filter.key)) + .map((filter) => filter.label.toLowerCase()) + .join(", "); +} diff --git a/surfsense_web/components/settings/model-connections/models-selection-panel.tsx b/surfsense_web/components/settings/model-connections/models-selection-panel.tsx new file mode 100644 index 000000000..20fbe862f --- /dev/null +++ b/surfsense_web/components/settings/model-connections/models-selection-panel.tsx @@ -0,0 +1,196 @@ +import { RefreshCcw } from "lucide-react"; +import { useState } from "react"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Checkbox } from "@/components/ui/checkbox"; +import { Input } from "@/components/ui/input"; +import type { ModelRead } from "@/contracts/types/model-connections.types"; +import { + capability, + capabilityLabels, + MODEL_CAPABILITY_FILTERS, + type ModelCapabilityFilter, + modelLabel, +} from "./model-utils"; + +interface ModelsSelectionPanelProps { + models: ModelRead[]; + description?: string; + emptyMessage?: string; + manualInputPlaceholder?: string; + refreshLabel?: string; + isRefreshing?: boolean; + isAddingManual?: boolean; + isUpdatingModel?: boolean; + isBulkUpdating?: boolean; + onRefresh?: () => void; + onAddManual?: (modelId: string) => void; + onToggleModel?: (model: ModelRead, enabled: boolean) => void; + onBulkToggle?: (models: ModelRead[], enabled: boolean) => void; +} + +export function ModelsSelectionPanel({ + models, + description = "Select models to make available for this provider.", + emptyMessage = "No models yet. Use the refresh button to discover models or add one manually.", + manualInputPlaceholder = "Add a model ID manually", + refreshLabel = "Refresh models", + isRefreshing = false, + isAddingManual = false, + isUpdatingModel = false, + isBulkUpdating = false, + onRefresh, + onAddManual, + onToggleModel, + onBulkToggle, +}: ModelsSelectionPanelProps) { + const [manualModelId, setManualModelId] = useState(""); + const [modelFilter, setModelFilter] = useState(null); + + const filteredModels = modelFilter + ? models.filter((model) => capability(model, modelFilter)) + : models; + const allFilteredModelsEnabled = + filteredModels.length > 0 && filteredModels.every((model) => model.enabled); + + function addModel() { + const modelId = manualModelId.trim(); + if (!modelId || !onAddManual) return; + onAddManual(modelId); + setManualModelId(""); + } + + function toggleFilteredModels() { + const nextEnabled = !allFilteredModelsEnabled; + const changedModels = filteredModels.filter((model) => model.enabled !== nextEnabled); + if (changedModels.length === 0) return; + onBulkToggle?.(changedModels, nextEnabled); + } + + return ( +
+
+
+
Models
+

{description}

+
+
+ + {onRefresh ? ( + + ) : null} +
+
+ + {onAddManual ? ( +
+ setManualModelId(event.target.value)} + onKeyDown={(event) => { + if (event.key === "Enter") { + event.preventDefault(); + addModel(); + } + }} + placeholder={manualInputPlaceholder} + /> + +
+ ) : null} + + {models.length > 0 ? ( +
+ Filter models + {MODEL_CAPABILITY_FILTERS.map((filter) => { + const count = models.filter((model) => capability(model, filter.key)).length; + const isActive = modelFilter === filter.key; + + return ( + + ); + })} +
+ ) : null} + +
+ {models.length === 0 ? ( +
+ {emptyMessage} +
+ ) : null} + {filteredModels.length === 0 && modelFilter ? ( +
+ No{" "} + {MODEL_CAPABILITY_FILTERS.find( + (filter) => filter.key === modelFilter + )?.label.toLowerCase()}{" "} + models found on this connection. +
+ ) : null} +
+ {filteredModels.map((model) => ( +
+ onToggleModel?.(model, checked === true)} + disabled={!onToggleModel || isUpdatingModel} + /> +
+
+ {modelLabel(model)} + {model.source === "MANUAL" ? ( + + manual + + ) : null} +
+
+ {capabilityLabels(model) || "No discovered capabilities"} +
+
+
+ ))} +
+
+
+ ); +} diff --git a/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx b/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx new file mode 100644 index 000000000..871a66cc5 --- /dev/null +++ b/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx @@ -0,0 +1,154 @@ +import { Button } from "@/components/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; +import type { + ConnectionRead, + ModelProviderRead, + ModelRead, +} from "@/contracts/types/model-connections.types"; +import { AzureConnectForm } from "./azure-connect-form"; +import { BedrockConnectForm } from "./bedrock-connect-form"; +import { DefaultConnectForm } from "./default-connect-form"; +import { ModelsSelectionPanel } from "./models-selection-panel"; +import { + type ConnectionDraft, + type ProviderConnectFormProps, + providerDefaultBaseUrl, + providerDisplay, + providerIcon, +} from "./provider-metadata"; +import { VertexConnectForm } from "./vertex-connect-form"; + +interface ProviderConnectDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + provider: string; + selectedProvider?: ModelProviderRead; + isPending: boolean; + onSubmit: (draft: ConnectionDraft) => void; + connectedConnection?: ConnectionRead | null; + connectModels?: ModelRead[]; + isDiscoveringModels?: boolean; + isAddingManualModel?: boolean; + isUpdatingModel?: boolean; + isBulkUpdatingModels?: boolean; + onRefreshModels?: () => void; + onAddManualModel?: (modelId: string) => void; + onToggleModel?: (model: ModelRead, enabled: boolean) => void; + onBulkToggleModels?: (models: ModelRead[], enabled: boolean) => void; + onDone?: () => void; +} + +/** + * Shared dialog shell for the "Add Provider" flow. It owns the header and routes + * to the provider-specific connect form. Forms remount on open (Radix unmounts + * closed content), so each gets fresh, prefilled state. + */ +export function ProviderConnectDialog({ + open, + onOpenChange, + provider, + selectedProvider, + isPending, + onSubmit, + connectedConnection, + connectModels = [], + isDiscoveringModels = false, + isAddingManualModel = false, + isUpdatingModel = false, + isBulkUpdatingModels = false, + onRefreshModels, + onAddManualModel, + onToggleModel, + onBulkToggleModels, + onDone, +}: ProviderConnectDialogProps) { + const meta = providerDisplay(provider); + const isModelSelectionStep = Boolean(connectedConnection); + + const formProps: ProviderConnectFormProps = { + provider, + defaultBaseUrl: providerDefaultBaseUrl(provider, selectedProvider?.default_base_url), + baseUrlRequired: Boolean(selectedProvider?.base_url_required), + isPending, + onCancel: () => onOpenChange(false), + onSubmit, + }; + + return ( + + + +
+ {providerIcon(provider, "size-5")} +
+ + {isModelSelectionStep ? `Select ${meta.name} models` : `Connect ${meta.name}`} + + + {isModelSelectionStep + ? selectedProvider?.discovery === "static" + ? "Choose from known model IDs or add one manually." + : "Choose which discovered models should be available in this search space." + : meta.subtitle} + +
+
+
+ {isModelSelectionStep ? ( + <> +
+ +
+ + + + + ) : ( +
+ {provider === "azure" ? ( + + ) : provider === "bedrock" ? ( + + ) : provider === "vertex_ai" ? ( + + ) : ( + + )} +
+ )} +
+
+ ); +} diff --git a/surfsense_web/components/settings/model-connections/provider-metadata.tsx b/surfsense_web/components/settings/model-connections/provider-metadata.tsx new file mode 100644 index 000000000..0ca8ae419 --- /dev/null +++ b/surfsense_web/components/settings/model-connections/provider-metadata.tsx @@ -0,0 +1,139 @@ +import { getProviderIcon } from "@/lib/provider-icons"; + +export const PROVIDER_ORDER = [ + "openai", + "anthropic", + "vertex_ai", + "bedrock", + "azure", + "openrouter", + "ollama_chat", + "lm_studio", + "openai_compatible", +]; + +export const PROVIDER_DISPLAY: Record< + string, + { name: string; subtitle: string; iconKey?: string; defaultBaseUrl?: string } +> = { + anthropic: { + name: "Claude", + subtitle: "Anthropic", + iconKey: "anthropic", + defaultBaseUrl: "https://api.anthropic.com/v1", + }, + azure: { name: "Azure OpenAI", subtitle: "Microsoft Azure", iconKey: "azure_openai" }, + bedrock: { name: "Amazon Bedrock", subtitle: "AWS", iconKey: "bedrock" }, + lm_studio: { name: "LM Studio", subtitle: "LM Studio", iconKey: "custom" }, + ollama_chat: { name: "Ollama", subtitle: "Ollama", iconKey: "ollama" }, + openai: { + name: "GPT", + subtitle: "OpenAI", + iconKey: "openai", + defaultBaseUrl: "https://api.openai.com/v1", + }, + openai_compatible: { + name: "OpenAI-Compatible", + subtitle: "OpenAI-compatible endpoint", + iconKey: "custom", + }, + openrouter: { + name: "OpenRouter", + subtitle: "OpenRouter", + iconKey: "openrouter", + defaultBaseUrl: "https://openrouter.ai/api/v1", + }, + vertex_ai: { name: "Gemini", subtitle: "Google Cloud Vertex AI", iconKey: "vertex_ai" }, +}; + +export function providerDisplay(provider: string) { + const fallback = provider + .split("_") + .filter(Boolean) + .map((part) => part.charAt(0).toUpperCase() + part.slice(1)) + .join(" "); + + return ( + PROVIDER_DISPLAY[provider] ?? { + name: fallback || provider, + subtitle: provider, + iconKey: provider, + } + ); +} + +export function providerIcon(provider: string, className = "size-4") { + return getProviderIcon(providerDisplay(provider).iconKey ?? provider, { className }); +} + +export function providerDefaultBaseUrl(provider: string, registryDefault?: string | null) { + return registryDefault ?? PROVIDER_DISPLAY[provider]?.defaultBaseUrl ?? ""; +} + +export const AWS_REGION_OPTIONS = [ + "us-east-1", + "us-east-2", + "us-west-2", + "us-gov-east-1", + "us-gov-west-1", + "ap-northeast-1", + "ap-south-1", + "ap-southeast-1", + "ap-southeast-2", + "ap-east-1", + "ca-central-1", + "eu-central-1", + "eu-west-2", +]; + +export const VERTEX_DEFAULT_LOCATION = "global"; + +export const BEDROCK_AUTH_IAM = "iam"; +export const BEDROCK_AUTH_ACCESS_KEY = "access_key"; +export const BEDROCK_AUTH_LONG_TERM_API_KEY = "long_term_api_key"; + +export const VERTEX_AUTH_SERVICE_ACCOUNT = "service_account_json"; +export const VERTEX_AUTH_WORKLOAD_IDENTITY = "workload_identity"; + +// Mirrors Onyx's Azure "Target URI" parser: the user pastes the full endpoint +// (e.g. https://res.cognitiveservices.azure.com/openai/deployments//chat/completions?api-version=) +// which we split into api base (origin), api version, and deployment name. +export function parseAzureTargetUri(rawUri: string) { + try { + const url = new URL(rawUri); + const deploymentMatch = url.pathname.match(/\/openai\/deployments\/([^/]+)/i); + return { + origin: url.origin, + apiVersion: url.searchParams.get("api-version")?.trim() ?? "", + deploymentName: deploymentMatch?.[1] ? deploymentMatch[1].toLowerCase() : "", + isResponsesPath: /\/openai\/responses/i.test(url.pathname), + }; + } catch { + return null; + } +} + +export function isValidAzureTargetUri(rawUri: string) { + const parsed = parseAzureTargetUri(rawUri); + if (!parsed) return false; + return Boolean(parsed.apiVersion) && (Boolean(parsed.deploymentName) || parsed.isResponsesPath); +} + +/** Connection payload produced by a provider connect form. */ +export interface ConnectionDraft { + base_url: string | null; + api_key: string | null; + extra: Record; + /** Model id to seed after creation (providers without discovery, e.g. Azure). */ + seedModelId?: string; +} + +/** Props shared by every provider-specific connect form. */ +export interface ProviderConnectFormProps { + provider: string; + defaultBaseUrl: string; + baseUrlRequired: boolean; + isPending: boolean; + onCancel: () => void; + onSubmit: (draft: ConnectionDraft) => void; +} diff --git a/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx b/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx new file mode 100644 index 000000000..096d3df2e --- /dev/null +++ b/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx @@ -0,0 +1,127 @@ +import { useState } from "react"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { ConnectFormFooter } from "./connect-fields"; +import { + type ProviderConnectFormProps, + VERTEX_AUTH_SERVICE_ACCOUNT, + VERTEX_AUTH_WORKLOAD_IDENTITY, + VERTEX_DEFAULT_LOCATION, +} from "./provider-metadata"; + +/** + * Google Vertex AI (Gemini) connect form. Service-account auth uploads a + * credentials JSON file (read into a string); workload identity collects a + * project id. Credentials ride along in `extra.litellm_params`. + */ +export function VertexConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) { + const [authMethod, setAuthMethod] = useState(VERTEX_AUTH_SERVICE_ACCOUNT); + const [location, setLocation] = useState(VERTEX_DEFAULT_LOCATION); + const [credentials, setCredentials] = useState(""); + const [project, setProject] = useState(""); + + const canSubmit = + authMethod === VERTEX_AUTH_SERVICE_ACCOUNT ? Boolean(credentials) : Boolean(project); + + async function handleCredentialsFile(file: File | undefined) { + if (!file) return; + setCredentials(await file.text()); + } + + function handleSubmit() { + const params: Record = {}; + if (location) params.vertex_location = location; + if (authMethod === VERTEX_AUTH_SERVICE_ACCOUNT) { + if (credentials) params.vertex_credentials = credentials; + } else if (project) { + params.vertex_project = project; + } + onSubmit({ base_url: null, api_key: null, extra: { litellm_params: params } }); + } + + return ( + <> +
+
+ + +
+
+ + setLocation(event.target.value)} + placeholder={VERTEX_DEFAULT_LOCATION} + /> +

+ Region where your Google Vertex AI models are hosted. +

+
+ {authMethod === VERTEX_AUTH_SERVICE_ACCOUNT ? ( +
+ + handleCredentialsFile(event.target.files?.[0])} + /> + +

+ {credentials + ? "Credentials file loaded." + : "Attach your service account key JSON from Google Cloud."} +

+
+ ) : ( +
+ + setProject(event.target.value)} + placeholder="my-vertex-project" + /> +

+ The GCP project where Vertex AI is enabled. +

+
+ )} +

+ Add Vertex AI model IDs from the provider's settings after connecting. +

+
+ + + ); +} diff --git a/surfsense_web/lib/provider-icons.tsx b/surfsense_web/lib/provider-icons.tsx index e63c5eb2f..3bb310904 100644 --- a/surfsense_web/lib/provider-icons.tsx +++ b/surfsense_web/lib/provider-icons.tsx @@ -1,4 +1,4 @@ -import { Bot, Shuffle } from "lucide-react"; +import { Cpu, Shuffle } from "lucide-react"; import { Ai21Icon, AnthropicIcon, @@ -72,7 +72,7 @@ export function getProviderIcon( case "COMETAPI": return ; case "CUSTOM": - return ; + return ; case "DATABRICKS": return ; case "DEEPINFRA": @@ -122,6 +122,6 @@ export function getProviderIcon( case "ZHIPU": return ; default: - return ; + return ; } } From 407f2a9612acbab423a5b490e1caa079158a3ecb Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 22:41:21 +0530 Subject: [PATCH 31/59] feat(model-connections): enhance model connection functionality with preview and selection features --- .../app/routes/model_connections_routes.py | 84 +++++++- surfsense_backend/app/schemas/__init__.py | 2 + .../app/schemas/model_connections.py | 27 +++ .../app/services/model_connection_service.py | 45 ++++ .../app/services/provider_registry.py | 3 +- .../model-connections-mutation.atoms.ts | 11 + .../agent-action-log/action-log-dialog.tsx | 4 +- .../settings/model-connections-settings.tsx | 192 ++++++++++-------- .../model-connections/azure-connect-form.tsx | 65 +++--- .../bedrock-connect-form.tsx | 144 ++++++------- .../model-connections/connect-fields.tsx | 32 ++- .../connection-settings-dialog.tsx | 20 +- .../default-connect-form.tsx | 48 ++--- .../settings/model-connections/model-utils.ts | 13 +- .../models-selection-panel.tsx | 10 +- .../provider-connect-dialog.tsx | 171 ++++++++-------- .../model-connections/provider-metadata.tsx | 4 +- .../model-connections/vertex-connect-form.tsx | 149 +++++++------- .../types/model-connections.types.ts | 19 ++ .../lib/apis/model-connections-api.service.ts | 16 ++ 20 files changed, 630 insertions(+), 429 deletions(-) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 730c68565..2405843a7 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -21,10 +21,12 @@ from app.schemas import ( ConnectionRead, ConnectionUpdate, ModelCreate, + ModelPreviewRead, ModelProviderRead, ModelRead, ModelRolesRead, ModelRolesUpdate, + ModelSelection, ModelsBulkUpdate, ModelUpdate, VerifyConnectionResponse, @@ -48,6 +50,21 @@ def _model_read(model: Model | dict) -> ModelRead: return ModelRead.model_validate(model) +def _preview_model_read(item: dict) -> ModelPreviewRead: + return ModelPreviewRead( + model_id=item["model_id"], + display_name=item.get("display_name"), + source=item.get("source", ModelSource.DISCOVERED), + supports_chat=item.get("supports_chat"), + max_input_tokens=item.get("max_input_tokens"), + supports_image_input=item.get("supports_image_input"), + supports_tools=item.get("supports_tools"), + supports_image_generation=item.get("supports_image_generation"), + enabled=item.get("enabled", False), + metadata=item.get("metadata") or item.get("catalog") or {}, + ) + + def _connection_read(conn: Connection | dict, models: list[Model | dict] | None = None) -> ConnectionRead: if isinstance(conn, dict): payload = { @@ -86,6 +103,25 @@ def _apply_model_facts(model: Model, facts: dict) -> None: model.supports_image_generation = facts.get("supports_image_generation") +def _selection_to_model(conn: Connection, selection: ModelSelection) -> Model: + source = ( + selection.source + if isinstance(selection.source, ModelSource) + else ModelSource(selection.source) + ) + model = Model( + connection_id=conn.id, + model_id=selection.model_id.strip(), + display_name=selection.display_name, + source=source, + capabilities_override={}, + enabled=selection.enabled, + catalog=selection.metadata, + ) + _apply_model_facts(model, selection.model_dump()) + return model + + def _default_model_for(models: list[Model], capability: str) -> int | None: for model in models: if model.enabled and has_capability(model, capability): @@ -226,7 +262,7 @@ async def create_connection( Permission.LLM_CONFIGS_CREATE.value, "You don't have permission to create model connections in this search space", ) - payload = data.model_dump(exclude={"search_space_id"}) + payload = data.model_dump(exclude={"search_space_id", "models"}) conn = Connection( **payload, @@ -234,9 +270,51 @@ async def create_connection( user_id=user.id, ) session.add(conn) + await session.flush() + + seen_model_ids: set[str] = set() + for selection in data.models: + model_id = selection.model_id.strip() + if not model_id or model_id in seen_model_ids: + continue + seen_model_ids.add(model_id) + session.add(_selection_to_model(conn, selection)) + await session.commit() - await session.refresh(conn) - return _connection_read(conn, []) + conn = await _load_connection(session, conn.id) + await _default_unset_roles(session, conn, list(conn.models)) + await session.commit() + conn = await _load_connection(session, conn.id) + return _connection_read(conn, list(conn.models)) + + +@router.post("/model-connections/discover-preview", response_model=list[ModelPreviewRead]) +async def preview_connection_models( + data: ConnectionCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None: + await check_permission( + session, + user, + data.search_space_id, + Permission.LLM_CONFIGS_CREATE.value, + "You don't have permission to create model connections in this search space", + ) + + draft = Connection( + provider=data.provider, + base_url=data.base_url, + api_key=data.api_key, + extra=data.extra or {}, + scope=data.scope, + enabled=data.enabled, + search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None, + user_id=user.id, + ) + discovered = await discover_models(draft) + return [_preview_model_read(item) for item in discovered] @router.put("/model-connections/{connection_id}", response_model=ConnectionRead) diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 55e712f12..efa448dcd 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -49,10 +49,12 @@ from .model_connections import ( ConnectionRead, ConnectionUpdate, ModelCreate, + ModelPreviewRead, ModelProviderRead, ModelRead, ModelRolesRead, ModelRolesUpdate, + ModelSelection, ModelsBulkUpdate, ModelUpdate, VerifyConnectionResponse, diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py index 0b03c7fab..896532d6f 100644 --- a/surfsense_backend/app/schemas/model_connections.py +++ b/surfsense_backend/app/schemas/model_connections.py @@ -48,6 +48,32 @@ class ConnectionRead(BaseModel): model_config = ConfigDict(from_attributes=True) +class ModelSelection(BaseModel): + model_id: str = Field(..., max_length=255) + display_name: str | None = Field(None, max_length=255) + source: ModelSource | str = ModelSource.DISCOVERED + supports_chat: bool | None = None + max_input_tokens: int | None = None + supports_image_input: bool | None = None + supports_tools: bool | None = None + supports_image_generation: bool | None = None + enabled: bool = False + metadata: dict[str, Any] = Field(default_factory=dict) + + +class ModelPreviewRead(BaseModel): + model_id: str + display_name: str | None = None + source: ModelSource | str = ModelSource.DISCOVERED + supports_chat: bool | None = None + max_input_tokens: int | None = None + supports_image_input: bool | None = None + supports_tools: bool | None = None + supports_image_generation: bool | None = None + enabled: bool = False + metadata: dict[str, Any] = Field(default_factory=dict) + + class ConnectionCreate(BaseModel): provider: str = Field(..., max_length=100) base_url: str | None = Field(None, max_length=500) @@ -56,6 +82,7 @@ class ConnectionCreate(BaseModel): scope: ConnectionScope = ConnectionScope.SEARCH_SPACE search_space_id: int | None = None enabled: bool = True + models: list[ModelSelection] = Field(default_factory=list) class ConnectionUpdate(BaseModel): diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index 428af736e..7742e837e 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -8,6 +8,7 @@ from dataclasses import dataclass from datetime import UTC, datetime from typing import Any +import anyio import httpx import litellm @@ -292,6 +293,48 @@ def _litellm_static_models(conn: Connection) -> list[dict[str, Any]]: return results +async def _discover_bedrock_models(conn: Connection) -> list[dict[str, Any]]: + params = (conn.extra or {}).get("litellm_params", {}) + region_name = params.get("aws_region_name") + if not region_name: + return [] + + def list_models() -> list[dict[str, Any]]: + import boto3 + + client_kwargs: dict[str, str] = {"region_name": region_name} + if params.get("aws_access_key_id"): + client_kwargs["aws_access_key_id"] = params["aws_access_key_id"] + if params.get("aws_secret_access_key"): + client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"] + + client = boto3.client("bedrock", **client_kwargs) + response = client.list_foundation_models() + results: list[dict[str, Any]] = [] + for item in response.get("modelSummaries", []): + model_id = item.get("modelId") + if not model_id: + continue + input_modalities = set(item.get("inputModalities") or []) + output_modalities = set(item.get("outputModalities") or []) + results.append( + { + "model_id": model_id, + "display_name": item.get("modelName") or model_id, + "source": ModelSource.DISCOVERED, + "supports_chat": "TEXT" in input_modalities and "TEXT" in output_modalities, + "supports_image_input": "IMAGE" in input_modalities, + "supports_tools": None, + "supports_image_generation": "IMAGE" in output_modalities, + "max_input_tokens": None, + "metadata": item, + } + ) + return results + + return await anyio.to_thread.run_sync(list_models) + + async def discover_models(conn: Connection) -> list[dict[str, Any]]: allowlist = _allowlist(conn) spec = spec_for(conn.provider) @@ -304,6 +347,8 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]: results = await _discover_anthropic_models(conn) elif spec.discovery == "openai_models": results = await _discover_openai_shaped_models(conn, conn.base_url) + elif spec.discovery == "bedrock_models": + results = await _discover_bedrock_models(conn) elif spec.discovery == "static": results = _litellm_static_models(conn) else: diff --git a/surfsense_backend/app/services/provider_registry.py b/surfsense_backend/app/services/provider_registry.py index 871769f11..98bfb63c1 100644 --- a/surfsense_backend/app/services/provider_registry.py +++ b/surfsense_backend/app/services/provider_registry.py @@ -21,6 +21,7 @@ DiscoveryKind = Literal[ "ollama", "openai_models", "anthropic_models", + "bedrock_models", "openrouter", "static", "none", @@ -51,7 +52,7 @@ REGISTRY: dict[str, ProviderSpec] = { Transport.NATIVE, "vertex_ai", "static", None, False, "native" ), "bedrock": ProviderSpec( - Transport.NATIVE, "bedrock", "static", None, False, "native" + Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native" ), "openrouter": ProviderSpec( Transport.OPENAI_COMPATIBLE, diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts index fee3b95ba..ea91c6483 100644 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -4,6 +4,7 @@ import type { ConnectionCreateRequest, ConnectionUpdateRequest, ModelCreateRequest, + ModelPreviewRead, ModelRead, ModelRoles, ModelsBulkUpdateRequest, @@ -103,6 +104,16 @@ export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => { }; }); +export const previewConnectionModelsMutationAtom = atomWithMutation(() => { + return { + mutationKey: ["model-connections", "discover-preview"], + mutationFn: (request: ConnectionCreateRequest) => + modelConnectionsApiService.previewModels(request), + onSuccess: (_models: ModelPreviewRead[]) => {}, + onError: (error: Error) => toast.error(error.message || "Failed to discover models"), + }; +}); + export const addManualModelMutationAtom = atomWithMutation((get) => { const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); return { diff --git a/surfsense_web/components/agent-action-log/action-log-dialog.tsx b/surfsense_web/components/agent-action-log/action-log-dialog.tsx index 1d0eefc17..5f3b83db1 100644 --- a/surfsense_web/components/agent-action-log/action-log-dialog.tsx +++ b/surfsense_web/components/agent-action-log/action-log-dialog.tsx @@ -2,7 +2,7 @@ import { useQueryClient } from "@tanstack/react-query"; import { useAtom, useAtomValue } from "jotai"; -import { RefreshCcw, Workflow } from "lucide-react"; +import { RefreshCw, Workflow } from "lucide-react"; import { useCallback } from "react"; import { actionLogDialogAtom } from "@/atoms/agent/action-log-dialog.atom"; import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom"; @@ -112,7 +112,7 @@ export function ActionLogDialog() { className="absolute right-14 top-4 size-8 rounded-full p-0 text-muted-foreground hover:bg-accent hover:text-accent-foreground" aria-label="Refresh action log" > - +
diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 2e15ce2e9..6c3d1a411 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -4,12 +4,9 @@ import { useAtom, useAtomValue } from "jotai"; import { CheckCircle2, Trash2, XCircle } from "lucide-react"; import { useState } from "react"; import { - addManualModelMutationAtom, - bulkUpdateModelsMutationAtom, createModelConnectionMutationAtom, deleteModelConnectionMutationAtom, - discoverConnectionModelsMutationAtom, - updateModelMutationAtom, + previewConnectionModelsMutationAtom, updateModelRolesMutationAtom, } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { @@ -41,9 +38,13 @@ import { SelectValue, } from "@/components/ui/select"; import { Separator } from "@/components/ui/separator"; -import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; +import type { + ConnectionRead, + ModelRead, + ModelSelection, +} from "@/contracts/types/model-connections.types"; import { ConnectionSettingsDialog } from "./model-connections/connection-settings-dialog"; -import { capability, modelLabel } from "./model-connections/model-utils"; +import { capability, modelLabel, type SelectableModel } from "./model-connections/model-utils"; import { ProviderConnectDialog } from "./model-connections/provider-connect-dialog"; import { type ConnectionDraft, @@ -154,16 +155,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num const [{ data: providers = [] }] = useAtom(modelProvidersAtom); const [{ data: roles }] = useAtom(modelRolesAtom); const createConnection = useAtomValue(createModelConnectionMutationAtom); - const addManualModel = useAtomValue(addManualModelMutationAtom); - const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); - const updateModel = useAtomValue(updateModelMutationAtom); - const bulkUpdateModels = useAtomValue(bulkUpdateModelsMutationAtom); + const previewModels = useAtomValue(previewConnectionModelsMutationAtom); const updateRoles = useAtomValue(updateModelRolesMutationAtom); const [isAddProviderOpen, setIsAddProviderOpen] = useState(false); const [provider, setProvider] = useState("openai_compatible"); - const [connectedConnection, setConnectedConnection] = useState(null); - const [connectModels, setConnectModels] = useState([]); + const [connectModels, setConnectModels] = useState([]); const selectedProvider = providers.find((item) => item.provider === provider); const sortedProviders = [...providers].sort((left, right) => { @@ -185,7 +182,6 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num const imageModels = enabledModels.filter((model) => capability(model, "image_gen")); function resetConnectState() { - setConnectedConnection(null); setConnectModels([]); } @@ -196,15 +192,48 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num } } - function replaceConnectModels(updatedModels: ModelRead[]) { - setConnectModels((current) => - current.map((model) => updatedModels.find((updated) => updated.id === model.id) ?? model) - ); + function toModelSelection(model: SelectableModel): ModelSelection { + return { + model_id: model.model_id, + display_name: model.display_name, + source: model.source || "DISCOVERED", + supports_chat: model.supports_chat, + max_input_tokens: model.max_input_tokens, + supports_image_input: model.supports_image_input, + supports_tools: model.supports_tools, + supports_image_generation: model.supports_image_generation, + enabled: model.enabled, + metadata: "metadata" in model ? (model.metadata ?? {}) : (model.catalog ?? {}), + }; + } + + function mergePreviewModels(fetchedModels: SelectableModel[]) { + setConnectModels((current) => { + const currentById = new Map(current.map((model) => [model.model_id, model])); + return fetchedModels.map((model) => { + const prior = currentById.get(model.model_id); + return { + ...toModelSelection(model), + enabled: prior ? prior.enabled : model.enabled, + }; + }); + }); } // Each provider connect form builds its own credential payload; the backend // resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM. function handleCreate(draft: ConnectionDraft) { + const models = [...connectModels]; + if (draft.seedModelId && !models.some((model) => model.model_id === draft.seedModelId)) { + models.push({ + model_id: draft.seedModelId, + display_name: draft.seedModelId, + source: "MANUAL", + enabled: true, + metadata: {}, + }); + } + createConnection.mutate( { provider, @@ -214,26 +243,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num search_space_id: searchSpaceId, extra: draft.extra, enabled: true, + models, }, { - onSuccess: (created) => { - setConnectedConnection(created); - setConnectModels([]); - if (draft.seedModelId) { - addManualModel.mutate( - { - connectionId: created.id, - data: { model_id: draft.seedModelId }, - }, - { - onSuccess: (model) => setConnectModels([model]), - } - ); - } else { - discoverModels.mutate(created.id, { - onSuccess: (models) => setConnectModels(models), - }); - } + onSuccess: () => { + setIsAddProviderOpen(false); + resetConnectState(); }, } ); @@ -243,52 +258,72 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num resetConnectState(); setProvider(providerId); setIsAddProviderOpen(true); + if (providerId === "vertex_ai") { + previewModels.mutate( + { + provider: providerId, + base_url: null, + api_key: null, + scope: "SEARCH_SPACE", + search_space_id: searchSpaceId, + extra: {}, + enabled: true, + models: [], + }, + { + onSuccess: mergePreviewModels, + } + ); + } } - function refreshConnectModels() { - if (!connectedConnection) return; - discoverModels.mutate(connectedConnection.id, { - onSuccess: (models) => setConnectModels(models), - }); + function refreshConnectModels(draft: ConnectionDraft) { + previewModels.mutate( + { + provider, + base_url: draft.base_url, + api_key: draft.api_key, + scope: "SEARCH_SPACE", + search_space_id: searchSpaceId, + extra: draft.extra, + enabled: true, + models: [], + }, + { + onSuccess: mergePreviewModels, + } + ); } function addConnectModel(modelId: string) { - if (!connectedConnection) return; - addManualModel.mutate( - { connectionId: connectedConnection.id, data: { model_id: modelId } }, - { - onSuccess: (model) => setConnectModels((current) => [...current, model]), - } + setConnectModels((current) => { + if (current.some((model) => model.model_id === modelId)) return current; + return [ + ...current, + { + model_id: modelId, + display_name: modelId, + source: "MANUAL", + enabled: true, + metadata: {}, + }, + ]; + }); + } + + function toggleConnectModel(model: SelectableModel, enabled: boolean) { + setConnectModels((current) => + current.map((item) => (item.model_id === model.model_id ? { ...item, enabled } : item)) ); } - function toggleConnectModel(model: ModelRead, enabled: boolean) { - updateModel.mutate( - { id: model.id, data: { enabled } }, - { - onSuccess: (updated) => replaceConnectModels([updated]), - } + function bulkToggleConnectModels(models: SelectableModel[], enabled: boolean) { + const modelIds = new Set(models.map((model) => model.model_id)); + setConnectModels((current) => + current.map((item) => (modelIds.has(item.model_id) ? { ...item, enabled } : item)) ); } - function bulkToggleConnectModels(models: ModelRead[], enabled: boolean) { - if (!connectedConnection) return; - bulkUpdateModels.mutate( - { - connectionId: connectedConnection.id, - data: { model_ids: models.map((model) => model.id), enabled }, - }, - { - onSuccess: replaceConnectModels, - } - ); - } - - function finishConnectFlow() { - setIsAddProviderOpen(false); - resetConnectState(); - } - function renderModelOption(model: ModelRead & { connectionName: string; provider: string }) { return ( @@ -347,17 +382,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num selectedProvider={selectedProvider} isPending={createConnection.isPending} onSubmit={handleCreate} - connectedConnection={connectedConnection} - connectModels={connectModels} - isDiscoveringModels={discoverModels.isPending} - isAddingManualModel={addManualModel.isPending} - isUpdatingModel={updateModel.isPending} - isBulkUpdatingModels={bulkUpdateModels.isPending} - onRefreshModels={refreshConnectModels} - onAddManualModel={addConnectModel} - onToggleModel={toggleConnectModel} - onBulkToggleModels={bulkToggleConnectModels} - onDone={finishConnectFlow} + previewModels={connectModels} + isPreviewingModels={previewModels.isPending} + onPreviewModels={refreshConnectModels} + onAddPreviewModel={addConnectModel} + onTogglePreviewModel={toggleConnectModel} + onBulkTogglePreviewModels={bulkToggleConnectModels} /> {connections.length > 0 ? ( diff --git a/surfsense_web/components/settings/model-connections/azure-connect-form.tsx b/surfsense_web/components/settings/model-connections/azure-connect-form.tsx index 11a2e25d3..451f053db 100644 --- a/surfsense_web/components/settings/model-connections/azure-connect-form.tsx +++ b/surfsense_web/components/settings/model-connections/azure-connect-form.tsx @@ -1,7 +1,7 @@ -import { useState } from "react"; +import { useEffect, useState } from "react"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; -import { ApiKeyField, ConnectFormFooter } from "./connect-fields"; +import { ApiKeyField } from "./connect-fields"; import { isValidAzureTargetUri, type ProviderConnectFormProps, @@ -12,48 +12,43 @@ import { * Azure OpenAI connect form. The user pastes a single Target URI, which we parse * into api base, api version, and the deployment name (seeded as the model). */ -export function AzureConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) { +export function AzureConnectForm({ onDraftChange }: ProviderConnectFormProps) { const [targetUri, setTargetUri] = useState(""); const [apiKey, setApiKey] = useState(""); const canSubmit = isValidAzureTargetUri(targetUri) && Boolean(apiKey.trim()); - function handleSubmit() { + useEffect(() => { const parsed = parseAzureTargetUri(targetUri); - onSubmit({ - base_url: parsed?.origin ?? null, - api_key: apiKey || null, - extra: parsed?.apiVersion ? { api_version: parsed.apiVersion } : {}, - seedModelId: parsed?.deploymentName || undefined, - }); - } + onDraftChange( + { + base_url: parsed?.origin ?? null, + api_key: apiKey || null, + extra: parsed?.apiVersion ? { api_version: parsed.apiVersion } : {}, + seedModelId: parsed?.deploymentName || undefined, + }, + canSubmit + ); + }, [apiKey, canSubmit, onDraftChange, targetUri]); return ( - <> -
-
- - setTargetUri(event.target.value)} - placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview" - /> -

- Paste your endpoint target URI from Azure OpenAI (including API base, deployment name, - and API version). -

-
- +
+ + setTargetUri(event.target.value)} + placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview" /> +

+ Paste your endpoint target URI from Azure OpenAI (including API base, deployment name, and + API version). +

- - +
); } diff --git a/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx b/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx index 9466c7cd1..3115ac223 100644 --- a/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx +++ b/surfsense_web/components/settings/model-connections/bedrock-connect-form.tsx @@ -1,4 +1,4 @@ -import { useState } from "react"; +import { useEffect, useState } from "react"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { @@ -8,7 +8,7 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; -import { ConnectFormFooter } from "./connect-fields"; +import { ApiKeyField } from "./connect-fields"; import { AWS_REGION_OPTIONS, BEDROCK_AUTH_ACCESS_KEY, @@ -21,7 +21,7 @@ import { * Amazon Bedrock connect form. Region + auth method drive which AWS credentials * are collected; everything rides along in `extra.litellm_params`. */ -export function BedrockConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) { +export function BedrockConnectForm({ onDraftChange }: ProviderConnectFormProps) { const [region, setRegion] = useState(""); const [authMethod, setAuthMethod] = useState(BEDROCK_AUTH_ACCESS_KEY); const [accessKeyId, setAccessKeyId] = useState(""); @@ -39,7 +39,7 @@ export function BedrockConnectForm({ isPending, onCancel, onSubmit }: ProviderCo return true; })(); - function handleSubmit() { + useEffect(() => { const params: Record = { aws_region_name: region }; if (authMethod === BEDROCK_AUTH_ACCESS_KEY) { params.aws_access_key_id = accessKeyId; @@ -47,88 +47,74 @@ export function BedrockConnectForm({ isPending, onCancel, onSubmit }: ProviderCo } else if (authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY) { params.aws_bearer_token_bedrock = bearerToken; } - onSubmit({ base_url: null, api_key: null, extra: { litellm_params: params } }); - } + onDraftChange({ base_url: null, api_key: null, extra: { litellm_params: params } }, canSubmit); + }, [accessKeyId, authMethod, bearerToken, canSubmit, onDraftChange, region, secretAccessKey]); return ( - <> -
-
- - -
-
- - -
- {authMethod === BEDROCK_AUTH_ACCESS_KEY ? ( - <> -
- - setAccessKeyId(event.target.value)} - placeholder="AKIAIOSFODNN7EXAMPLE" - /> -
-
- - setSecretAccessKey(event.target.value)} - placeholder="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" - type="password" - /> -
- - ) : null} - {authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY ? ( +
+
+ + +
+
+ + +
+ {authMethod === BEDROCK_AUTH_ACCESS_KEY ? ( + <>
- + setBearerToken(event.target.value)} - placeholder="Your long-term API key" - type="password" + value={accessKeyId} + onChange={(event) => setAccessKeyId(event.target.value)} + placeholder="AKIAIOSFODNN7EXAMPLE" />
- ) : null} - {authMethod === BEDROCK_AUTH_IAM ? ( -

- SurfSense will use the IAM role attached to the environment it's running in to - authenticate. -

- ) : null} + + + ) : null} + {authMethod === BEDROCK_AUTH_LONG_TERM_API_KEY ? ( + + ) : null} + {authMethod === BEDROCK_AUTH_IAM ? (

- Add Bedrock model IDs from the provider's settings after connecting. + SurfSense will use the IAM role attached to the environment it's running in to + authenticate.

-
- - + ) : null} +

+ Add Bedrock model IDs from the provider's settings after connecting. +

+
); } diff --git a/surfsense_web/components/settings/model-connections/connect-fields.tsx b/surfsense_web/components/settings/model-connections/connect-fields.tsx index af8db7f12..44b2d434f 100644 --- a/surfsense_web/components/settings/model-connections/connect-fields.tsx +++ b/surfsense_web/components/settings/model-connections/connect-fields.tsx @@ -1,3 +1,5 @@ +import { Eye, EyeOff } from "lucide-react"; +import { useState } from "react"; import { Button } from "@/components/ui/button"; import { DialogFooter } from "@/components/ui/dialog"; import { Input } from "@/components/ui/input"; @@ -43,15 +45,31 @@ export function ApiKeyField({ label = "API Key", placeholder = "API key", }: ApiKeyFieldProps) { + const [showApiKey, setShowApiKey] = useState(false); + return (
- onChange(event.target.value)} - placeholder={placeholder} - type="password" - /> +
+ onChange(event.target.value)} + placeholder={placeholder} + type={showApiKey ? "text" : "password"} + className="pr-11" + /> + +
); } @@ -71,7 +89,7 @@ export function ConnectFormFooter({ isPending, }: ConnectFormFooterProps) { return ( - + diff --git a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx index f3821af46..d0f8e6c16 100644 --- a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx +++ b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx @@ -25,8 +25,8 @@ import { Separator } from "@/components/ui/separator"; import type { ConnectionRead, ConnectionUpdateRequest, - ModelRead, } from "@/contracts/types/model-connections.types"; +import type { SelectableModel } from "./model-utils"; import { ModelsSelectionPanel } from "./models-selection-panel"; import { providerIcon } from "./provider-metadata"; @@ -101,17 +101,22 @@ export function ConnectionSettingsDialog({ }); } - function handleToggleModel(model: ModelRead, enabled: boolean) { + function handleToggleModel(model: SelectableModel, enabled: boolean) { + if (typeof model.id !== "number") return; updateModel.mutate({ id: model.id, data: { enabled }, }); } - function handleBulkToggle(models: ModelRead[], enabled: boolean) { + function handleBulkToggle(models: SelectableModel[], enabled: boolean) { + const modelIds = models + .map((model) => model.id) + .filter((id): id is number => typeof id === "number"); + if (modelIds.length === 0) return; bulkUpdateModels.mutate({ connectionId: connection.id, - data: { model_ids: models.map((model) => model.id), enabled }, + data: { model_ids: modelIds, enabled }, }); } @@ -184,12 +189,7 @@ export function ConnectionSettingsDialog({ onChange={(event) => setAllowlistText(event.target.value)} placeholder="Comma-separated, e.g. anthropic/claude-sonnet-4-5, google/gemini-2.5-pro" /> -
diff --git a/surfsense_web/components/settings/model-connections/default-connect-form.tsx b/surfsense_web/components/settings/model-connections/default-connect-form.tsx index 3f261c6b2..768c0b5da 100644 --- a/surfsense_web/components/settings/model-connections/default-connect-form.tsx +++ b/surfsense_web/components/settings/model-connections/default-connect-form.tsx @@ -1,5 +1,5 @@ -import { useState } from "react"; -import { ApiBaseUrlField, ApiKeyField, ConnectFormFooter } from "./connect-fields"; +import { useEffect, useState } from "react"; +import { ApiBaseUrlField, ApiKeyField } from "./connect-fields"; import type { ProviderConnectFormProps } from "./provider-metadata"; /** @@ -11,41 +11,31 @@ export function DefaultConnectForm({ provider, defaultBaseUrl, baseUrlRequired, - isPending, - onCancel, - onSubmit, + onDraftChange, }: ProviderConnectFormProps) { const [baseUrl, setBaseUrl] = useState(defaultBaseUrl); const [apiKey, setApiKey] = useState(""); const isOllama = provider === "ollama_chat"; const canSubmit = !(baseUrlRequired && !baseUrl.trim()); - function handleSubmit() { - onSubmit({ base_url: baseUrl || null, api_key: apiKey || null, extra: {} }); - } + useEffect(() => { + onDraftChange({ base_url: baseUrl || null, api_key: apiKey || null, extra: {} }, canSubmit); + }, [apiKey, baseUrl, canSubmit, onDraftChange]); return ( - <> -
- - -
- + - + +
); } diff --git a/surfsense_web/components/settings/model-connections/model-utils.ts b/surfsense_web/components/settings/model-connections/model-utils.ts index 1db14b3eb..2887f2179 100644 --- a/surfsense_web/components/settings/model-connections/model-utils.ts +++ b/surfsense_web/components/settings/model-connections/model-utils.ts @@ -1,4 +1,4 @@ -import type { ModelRead } from "@/contracts/types/model-connections.types"; +import type { ModelPreviewRead, ModelRead } from "@/contracts/types/model-connections.types"; export type ModelCapabilityFilter = "chat" | "vision" | "image_gen"; @@ -8,17 +8,22 @@ export const MODEL_CAPABILITY_FILTERS: { key: ModelCapabilityFilter; label: stri { key: "image_gen", label: "Image" }, ]; -export function modelLabel(model: ModelRead) { +export type SelectableModel = (ModelRead | ModelPreviewRead) & { + id?: number | string; + connection_id?: number; +}; + +export function modelLabel(model: SelectableModel) { return model.display_name || model.model_id; } -export function capability(model: ModelRead, key: ModelCapabilityFilter) { +export function capability(model: SelectableModel, key: ModelCapabilityFilter) { if (key === "chat") return Boolean(model.supports_chat); if (key === "vision") return Boolean(model.supports_image_input); return Boolean(model.supports_image_generation); } -export function capabilityLabels(model: ModelRead) { +export function capabilityLabels(model: SelectableModel) { return MODEL_CAPABILITY_FILTERS.filter((filter) => capability(model, filter.key)) .map((filter) => filter.label.toLowerCase()) .join(", "); diff --git a/surfsense_web/components/settings/model-connections/models-selection-panel.tsx b/surfsense_web/components/settings/model-connections/models-selection-panel.tsx index 20fbe862f..01ff0d1e7 100644 --- a/surfsense_web/components/settings/model-connections/models-selection-panel.tsx +++ b/surfsense_web/components/settings/model-connections/models-selection-panel.tsx @@ -4,17 +4,17 @@ import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Checkbox } from "@/components/ui/checkbox"; import { Input } from "@/components/ui/input"; -import type { ModelRead } from "@/contracts/types/model-connections.types"; import { capability, capabilityLabels, MODEL_CAPABILITY_FILTERS, type ModelCapabilityFilter, modelLabel, + type SelectableModel, } from "./model-utils"; interface ModelsSelectionPanelProps { - models: ModelRead[]; + models: SelectableModel[]; description?: string; emptyMessage?: string; manualInputPlaceholder?: string; @@ -25,8 +25,8 @@ interface ModelsSelectionPanelProps { isBulkUpdating?: boolean; onRefresh?: () => void; onAddManual?: (modelId: string) => void; - onToggleModel?: (model: ModelRead, enabled: boolean) => void; - onBulkToggle?: (models: ModelRead[], enabled: boolean) => void; + onToggleModel?: (model: SelectableModel, enabled: boolean) => void; + onBulkToggle?: (models: SelectableModel[], enabled: boolean) => void; } export function ModelsSelectionPanel({ @@ -166,7 +166,7 @@ export function ModelsSelectionPanel({
{filteredModels.map((model) => (
void; - connectedConnection?: ConnectionRead | null; - connectModels?: ModelRead[]; - isDiscoveringModels?: boolean; - isAddingManualModel?: boolean; - isUpdatingModel?: boolean; - isBulkUpdatingModels?: boolean; - onRefreshModels?: () => void; - onAddManualModel?: (modelId: string) => void; - onToggleModel?: (model: ModelRead, enabled: boolean) => void; - onBulkToggleModels?: (models: ModelRead[], enabled: boolean) => void; - onDone?: () => void; + previewModels?: SelectableModel[]; + isPreviewingModels?: boolean; + onPreviewModels?: (draft: ConnectionDraft) => void; + onAddPreviewModel?: (modelId: string) => void; + onTogglePreviewModel?: (model: SelectableModel, enabled: boolean) => void; + onBulkTogglePreviewModels?: (models: SelectableModel[], enabled: boolean) => void; } /** @@ -57,97 +50,93 @@ export function ProviderConnectDialog({ selectedProvider, isPending, onSubmit, - connectedConnection, - connectModels = [], - isDiscoveringModels = false, - isAddingManualModel = false, - isUpdatingModel = false, - isBulkUpdatingModels = false, - onRefreshModels, - onAddManualModel, - onToggleModel, - onBulkToggleModels, - onDone, + previewModels = [], + isPreviewingModels = false, + onPreviewModels, + onAddPreviewModel, + onTogglePreviewModel, + onBulkTogglePreviewModels, }: ProviderConnectDialogProps) { const meta = providerDisplay(provider); - const isModelSelectionStep = Boolean(connectedConnection); + const isAzure = provider === "azure"; + const isBedrock = provider === "bedrock"; + const isVertex = provider === "vertex_ai"; + const [currentDraft, setCurrentDraft] = useState({ + base_url: null, + api_key: null, + extra: {}, + }); + const [canSubmit, setCanSubmit] = useState(false); + + const handleDraftChange = useCallback((draft: ConnectionDraft, nextCanSubmit: boolean) => { + setCurrentDraft(draft); + setCanSubmit(nextCanSubmit); + }, []); const formProps: ProviderConnectFormProps = { provider, defaultBaseUrl: providerDefaultBaseUrl(provider, selectedProvider?.default_base_url), baseUrlRequired: Boolean(selectedProvider?.base_url_required), - isPending, - onCancel: () => onOpenChange(false), - onSubmit, + onDraftChange: handleDraftChange, }; + const modelDescription = (() => { + if (isAzure) { + return "Select the models to enable for Azure OpenAI"; + } + if (isBedrock) { + return "Select the models to enable for Amazon Bedrock"; + } + if (isVertex) { + return "Select the models to enable for Gemini"; + } + return "Select the models to enable for this provider"; + })(); + + const canRefreshModels = !isAzure && !isVertex && (!isBedrock || canSubmit); + return ( - +
{providerIcon(provider, "size-5")}
- - {isModelSelectionStep ? `Select ${meta.name} models` : `Connect ${meta.name}`} - - - {isModelSelectionStep - ? selectedProvider?.discovery === "static" - ? "Choose from known model IDs or add one manually." - : "Choose which discovered models should be available in this search space." - : meta.subtitle} - + Connect {meta.name} + {meta.subtitle}
- {isModelSelectionStep ? ( - <> -
- -
- - - - - ) : ( -
- {provider === "azure" ? ( - - ) : provider === "bedrock" ? ( - - ) : provider === "vertex_ai" ? ( - - ) : ( - - )} -
- )} +
+ {provider === "azure" ? ( + + ) : provider === "bedrock" ? ( + + ) : provider === "vertex_ai" ? ( + + ) : ( + + )} + + + + onPreviewModels?.(currentDraft) : undefined} + onAddManual={onAddPreviewModel} + onToggleModel={onTogglePreviewModel} + onBulkToggle={onBulkTogglePreviewModels} + /> +
+ onOpenChange(false)} + onSubmit={() => onSubmit(currentDraft)} + canSubmit={canSubmit} + isPending={isPending} + />
); diff --git a/surfsense_web/components/settings/model-connections/provider-metadata.tsx b/surfsense_web/components/settings/model-connections/provider-metadata.tsx index 0ca8ae419..73e873393 100644 --- a/surfsense_web/components/settings/model-connections/provider-metadata.tsx +++ b/surfsense_web/components/settings/model-connections/provider-metadata.tsx @@ -133,7 +133,5 @@ export interface ProviderConnectFormProps { provider: string; defaultBaseUrl: string; baseUrlRequired: boolean; - isPending: boolean; - onCancel: () => void; - onSubmit: (draft: ConnectionDraft) => void; + onDraftChange: (draft: ConnectionDraft, canSubmit: boolean) => void; } diff --git a/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx b/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx index 096d3df2e..1027742bc 100644 --- a/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx +++ b/surfsense_web/components/settings/model-connections/vertex-connect-form.tsx @@ -1,4 +1,4 @@ -import { useState } from "react"; +import { useEffect, useState } from "react"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { @@ -8,7 +8,6 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; -import { ConnectFormFooter } from "./connect-fields"; import { type ProviderConnectFormProps, VERTEX_AUTH_SERVICE_ACCOUNT, @@ -21,7 +20,7 @@ import { * credentials JSON file (read into a string); workload identity collects a * project id. Credentials ride along in `extra.litellm_params`. */ -export function VertexConnectForm({ isPending, onCancel, onSubmit }: ProviderConnectFormProps) { +export function VertexConnectForm({ onDraftChange }: ProviderConnectFormProps) { const [authMethod, setAuthMethod] = useState(VERTEX_AUTH_SERVICE_ACCOUNT); const [location, setLocation] = useState(VERTEX_DEFAULT_LOCATION); const [credentials, setCredentials] = useState(""); @@ -35,7 +34,7 @@ export function VertexConnectForm({ isPending, onCancel, onSubmit }: ProviderCon setCredentials(await file.text()); } - function handleSubmit() { + useEffect(() => { const params: Record = {}; if (location) params.vertex_location = location; if (authMethod === VERTEX_AUTH_SERVICE_ACCOUNT) { @@ -43,85 +42,77 @@ export function VertexConnectForm({ isPending, onCancel, onSubmit }: ProviderCon } else if (project) { params.vertex_project = project; } - onSubmit({ base_url: null, api_key: null, extra: { litellm_params: params } }); - } + onDraftChange({ base_url: null, api_key: null, extra: { litellm_params: params } }, canSubmit); + }, [authMethod, canSubmit, credentials, location, onDraftChange, project]); return ( - <> -
-
- - -
-
- - setLocation(event.target.value)} - placeholder={VERTEX_DEFAULT_LOCATION} - /> -

- Region where your Google Vertex AI models are hosted. -

-
- {authMethod === VERTEX_AUTH_SERVICE_ACCOUNT ? ( -
- - handleCredentialsFile(event.target.files?.[0])} - /> - -

- {credentials - ? "Credentials file loaded." - : "Attach your service account key JSON from Google Cloud."} -

-
- ) : ( -
- - setProject(event.target.value)} - placeholder="my-vertex-project" - /> -

- The GCP project where Vertex AI is enabled. -

-
- )} +
+
+ + +
+
+ + setLocation(event.target.value)} + placeholder={VERTEX_DEFAULT_LOCATION} + />

- Add Vertex AI model IDs from the provider's settings after connecting. + Region where your Google Vertex AI models are hosted.

- - + {authMethod === VERTEX_AUTH_SERVICE_ACCOUNT ? ( +
+ + handleCredentialsFile(event.target.files?.[0])} + /> + +

+ {credentials + ? "Credentials file loaded." + : "Attach your service account key JSON from Google Cloud."} +

+
+ ) : ( +
+ + setProject(event.target.value)} + placeholder="my-vertex-project" + /> +

+ The GCP project where Vertex AI is enabled. +

+
+ )} +

+ Add Vertex AI model IDs from the provider's settings after connecting. +

+
); } diff --git a/surfsense_web/contracts/types/model-connections.types.ts b/surfsense_web/contracts/types/model-connections.types.ts index c75f4c90a..134c740b2 100644 --- a/surfsense_web/contracts/types/model-connections.types.ts +++ b/surfsense_web/contracts/types/model-connections.types.ts @@ -40,6 +40,21 @@ export const connectionRead = z.object({ created_at: z.string().nullable().optional(), }); +export const modelSelection = z.object({ + model_id: z.string().min(1), + display_name: z.string().nullable().optional(), + source: z.union([modelSourceEnum, z.string()]).default("DISCOVERED"), + supports_chat: z.boolean().nullable().optional(), + max_input_tokens: z.number().nullable().optional(), + supports_image_input: z.boolean().nullable().optional(), + supports_tools: z.boolean().nullable().optional(), + supports_image_generation: z.boolean().nullable().optional(), + enabled: z.boolean().default(false), + metadata: z.record(z.string(), z.any()).default({}), +}); + +export const modelPreviewRead = modelSelection; + export const connectionCreateRequest = z.object({ provider: z.string().min(1), base_url: z.string().nullable().optional(), @@ -48,6 +63,7 @@ export const connectionCreateRequest = z.object({ scope: connectionScopeEnum.default("SEARCH_SPACE"), search_space_id: z.number().nullable().optional(), enabled: z.boolean().default(true), + models: z.array(modelSelection).default([]), }); export const connectionUpdateRequest = z.object({ @@ -105,9 +121,12 @@ export const modelProviderListResponse = z.array(modelProviderRead); export const connectionListResponse = z.array(connectionRead); export const modelListResponse = z.array(modelRead); +export const modelPreviewListResponse = z.array(modelPreviewRead); export type ConnectionScope = z.infer; export type ModelRead = z.infer; +export type ModelPreviewRead = z.infer; +export type ModelSelection = z.infer; export type ConnectionRead = z.infer; export type ConnectionCreateRequest = z.infer; export type ConnectionUpdateRequest = z.infer; diff --git a/surfsense_web/lib/apis/model-connections-api.service.ts b/surfsense_web/lib/apis/model-connections-api.service.ts index bd5aa1309..f463a27e7 100644 --- a/surfsense_web/lib/apis/model-connections-api.service.ts +++ b/surfsense_web/lib/apis/model-connections-api.service.ts @@ -7,6 +7,7 @@ import { connectionRead, connectionUpdateRequest, type ModelCreateRequest, + type ModelPreviewRead, type ModelProviderRead, type ModelRead, type ModelRoles, @@ -14,6 +15,7 @@ import { type ModelUpdateRequest, modelCreateRequest, modelListResponse, + modelPreviewListResponse, modelProviderListResponse, modelRead, modelRoles, @@ -76,6 +78,20 @@ class ModelConnectionsApiService { return baseApiService.post(`/api/v1/model-connections/${id}/discover`, modelListResponse); }; + previewModels = async (request: ConnectionCreateRequest): Promise => { + const parsed = connectionCreateRequest.safeParse(request); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); + } + return baseApiService.post( + `/api/v1/model-connections/discover-preview`, + modelPreviewListResponse, + { + body: parsed.data, + } + ); + }; + addManualModel = async ( connectionId: number, request: ModelCreateRequest From 55f004e1da4bf6297aa8dcf3215ab0ca825c8051 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 12 Jun 2026 22:50:50 +0530 Subject: [PATCH 32/59] feat(model-connections): improve model discovery error handling and enhance UI components --- .../app/routes/model_connections_routes.py | 11 +++- .../app/services/model_connection_service.py | 53 ++++++++++++++----- .../models-selection-panel.tsx | 15 +++--- 3 files changed, 55 insertions(+), 24 deletions(-) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 2405843a7..474d376d3 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -32,6 +32,7 @@ from app.schemas import ( VerifyConnectionResponse, ) from app.services.model_connection_service import ( + ModelDiscoveryError, derive_capabilities, discover_models, persist_verification, @@ -313,7 +314,10 @@ async def preview_connection_models( search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None, user_id=user.id, ) - discovered = await discover_models(draft) + try: + discovered = await discover_models(draft) + except ModelDiscoveryError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc return [_preview_model_read(item) for item in discovered] @@ -367,7 +371,10 @@ async def discover_connection_models( ): conn = await _load_connection(session, connection_id) await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_CREATE.value) - discovered = await discover_models(conn) + try: + discovered = await discover_models(conn) + except ModelDiscoveryError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc by_model_id = {model.model_id: model for model in conn.models} for item in discovered: db_model = by_model_id.get(item["model_id"]) diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index 7742e837e..c9ee2779f 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -31,6 +31,10 @@ class VerifyResult: message: str = "" +class ModelDiscoveryError(Exception): + """User-correctable discovery failure for provider configuration issues.""" + + def _auth_headers(conn: Connection) -> dict[str, str]: if not conn.api_key: return {} @@ -120,6 +124,23 @@ async def persist_verification(conn: Connection) -> VerifyResult: return result +def _discovery_error_message(conn: Connection, exc: httpx.HTTPError) -> str: + base_url = _base_url_or_default(conn) + if isinstance(exc, httpx.HTTPStatusError): + status_code = exc.response.status_code + if status_code in (401, 403): + return "Authentication failed while discovering models." + if status_code == 404: + spec = spec_for(conn.provider) + if spec.transport == Transport.OPENAI_COMPATIBLE: + return "OpenAI-compatible servers should expose /v1/models." + return "Model discovery endpoint returned 404." + return f"Model discovery failed with HTTP {status_code}." + if isinstance(exc, httpx.TimeoutException): + return f"Model discovery timed out: {exc}" + return _docker_hint(base_url, exc) + + def _allowlist(conn: Connection) -> set[str]: raw = (conn.extra or {}).get("model_ids") or [] return {str(item).strip() for item in raw if str(item).strip()} @@ -339,20 +360,23 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]: allowlist = _allowlist(conn) spec = spec_for(conn.provider) - if spec.discovery == "ollama": - results = await _ollama_tags_then_show(conn) - elif spec.discovery == "openrouter": - results = await _openrouter_models(conn) - elif spec.discovery == "anthropic_models": - results = await _discover_anthropic_models(conn) - elif spec.discovery == "openai_models": - results = await _discover_openai_shaped_models(conn, conn.base_url) - elif spec.discovery == "bedrock_models": - results = await _discover_bedrock_models(conn) - elif spec.discovery == "static": - results = _litellm_static_models(conn) - else: - results = [] + try: + if spec.discovery == "ollama": + results = await _ollama_tags_then_show(conn) + elif spec.discovery == "openrouter": + results = await _openrouter_models(conn) + elif spec.discovery == "anthropic_models": + results = await _discover_anthropic_models(conn) + elif spec.discovery == "openai_models": + results = await _discover_openai_shaped_models(conn, conn.base_url) + elif spec.discovery == "bedrock_models": + results = await _discover_bedrock_models(conn) + elif spec.discovery == "static": + results = _litellm_static_models(conn) + else: + results = [] + except httpx.HTTPError as exc: + raise ModelDiscoveryError(_discovery_error_message(conn, exc)) from exc if allowlist: results = [item for item in results if item["model_id"] in allowlist] @@ -376,6 +400,7 @@ async def test_model(conn: Connection, model: Model) -> VerifyResult: __all__ = [ + "ModelDiscoveryError", "VerifyResult", "derive_capabilities", "discover_models", diff --git a/surfsense_web/components/settings/model-connections/models-selection-panel.tsx b/surfsense_web/components/settings/model-connections/models-selection-panel.tsx index 01ff0d1e7..573049f6c 100644 --- a/surfsense_web/components/settings/model-connections/models-selection-panel.tsx +++ b/surfsense_web/components/settings/model-connections/models-selection-panel.tsx @@ -1,4 +1,4 @@ -import { RefreshCcw } from "lucide-react"; +import { RefreshCw } from "lucide-react"; import { useState } from "react"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; @@ -32,7 +32,7 @@ interface ModelsSelectionPanelProps { export function ModelsSelectionPanel({ models, description = "Select models to make available for this provider.", - emptyMessage = "No models yet. Use the refresh button to discover models or add one manually.", + emptyMessage = "No models available.", manualInputPlaceholder = "Add a model ID manually", refreshLabel = "Refresh models", isRefreshing = false, @@ -86,14 +86,14 @@ export function ModelsSelectionPanel({ {onRefresh ? ( ) : null}
@@ -113,7 +113,6 @@ export function ModelsSelectionPanel({ placeholder={manualInputPlaceholder} />
) : null} -
+
{models.length === 0 ? (
{emptyMessage} From 9f6210ad089788304c1a2bbb1068e505cefe18c3 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 00:12:04 +0530 Subject: [PATCH 33/59] feat(model-connections): add test preview functionality for model connections --- .../app/routes/model_connections_routes.py | 46 +++++++++- surfsense_backend/app/schemas/__init__.py | 20 ++++- .../app/schemas/model_connections.py | 4 + .../app/services/model_connection_service.py | 87 +++++++++++++++++-- .../app/services/model_resolver.py | 2 + .../app/services/provider_registry.py | 27 ++++-- .../model-connections-mutation.atoms.ts | 15 ++++ .../settings/model-connections-settings.tsx | 82 +++++++++-------- .../connection-settings-dialog.tsx | 63 +++++++++----- .../provider-connect-dialog.tsx | 4 +- .../types/model-connections.types.ts | 5 ++ .../lib/apis/model-connections-api.service.ts | 16 ++++ 12 files changed, 294 insertions(+), 77 deletions(-) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 474d376d3..76e4a3dfb 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -26,11 +26,13 @@ from app.schemas import ( ModelRead, ModelRolesRead, ModelRolesUpdate, - ModelSelection, ModelsBulkUpdate, + ModelSelection, + ModelTestPreview, ModelUpdate, VerifyConnectionResponse, ) +from app.services.model_capabilities import has_capability from app.services.model_connection_service import ( ModelDiscoveryError, derive_capabilities, @@ -38,7 +40,6 @@ from app.services.model_connection_service import ( persist_verification, test_model, ) -from app.services.model_capabilities import has_capability from app.services.provider_registry import REGISTRY from app.users import current_active_user from app.utils.rbac import check_permission @@ -321,6 +322,47 @@ async def preview_connection_models( return [_preview_model_read(item) for item in discovered] +@router.post("/model-connections/test-preview", response_model=VerifyConnectionResponse) +async def test_preview_connection_model( + data: ModelTestPreview, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None: + await check_permission( + session, + user, + data.search_space_id, + Permission.LLM_CONFIGS_CREATE.value, + "You don't have permission to create model connections in this search space", + ) + + model_id = data.model_id.strip() + if not model_id: + raise HTTPException(status_code=400, detail="model_id is required") + + draft = Connection( + provider=data.provider, + base_url=data.base_url, + api_key=data.api_key, + extra=data.extra or {}, + scope=data.scope, + enabled=data.enabled, + search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None, + user_id=user.id, + ) + model = Model( + connection_id=0, + model_id=model_id, + source=ModelSource.MANUAL, + enabled=True, + capabilities_override={}, + catalog={}, + ) + result = await test_model(draft, model) + return VerifyConnectionResponse(status=result.status, ok=result.ok, message=result.message) + + @router.put("/model-connections/{connection_id}", response_model=ConnectionRead) async def update_connection( connection_id: int, diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index efa448dcd..3c4fdfa83 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -54,8 +54,9 @@ from .model_connections import ( ModelRead, ModelRolesRead, ModelRolesUpdate, - ModelSelection, ModelsBulkUpdate, + ModelSelection, + ModelTestPreview, ModelUpdate, VerifyConnectionResponse, ) @@ -149,7 +150,7 @@ from .vision_llm import ( VisionLLMConfigUpdate, ) -__all__ = [ +__all__ = [ # Folder schemas "BulkDocumentMove", # Chat schemas (assistant-ui integration) @@ -159,6 +160,10 @@ __all__ = [ "ChunkCreate", "ChunkRead", "ChunkUpdate", + # Model connection schemas + "ConnectionCreate", + "ConnectionRead", + "ConnectionUpdate", "CreateCreditCheckoutSessionRequest", "CreateCreditCheckoutSessionResponse", "CreditPurchaseHistoryResponse", @@ -232,6 +237,16 @@ __all__ = [ "MembershipRead", "MembershipReadWithUser", "MembershipUpdate", + "ModelCreate", + "ModelPreviewRead", + "ModelProviderRead", + "ModelRead", + "ModelRolesRead", + "ModelRolesUpdate", + "ModelSelection", + "ModelTestPreview", + "ModelUpdate", + "ModelsBulkUpdate", "NewChatMessageAppend", "NewChatMessageCreate", "NewChatMessageRead", @@ -282,6 +297,7 @@ __all__ = [ "UserRead", "UserSearchSpaceAccess", "UserUpdate", + "VerifyConnectionResponse", # Video Presentation schemas "VideoPresentationBase", "VideoPresentationCreate", diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py index 896532d6f..67d94f821 100644 --- a/surfsense_backend/app/schemas/model_connections.py +++ b/surfsense_backend/app/schemas/model_connections.py @@ -85,6 +85,10 @@ class ConnectionCreate(BaseModel): models: list[ModelSelection] = Field(default_factory=list) +class ModelTestPreview(ConnectionCreate): + model_id: str = Field(..., max_length=255) + + class ConnectionUpdate(BaseModel): provider: str | None = Field(None, max_length=100) base_url: str | None = Field(None, max_length=500) diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index c9ee2779f..fbfdd437f 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -15,7 +15,7 @@ import litellm from app.db import Connection, Model, ModelSource from app.services.model_resolver import ensure_v1, to_litellm from app.services.openrouter_model_normalizer import normalize_openrouter_models -from app.services.provider_registry import Transport, spec_for +from app.services.provider_registry import Transport, provider_label, spec_for logger = logging.getLogger(__name__) @@ -77,6 +77,68 @@ def _docker_hint(url: str | None, exc_or_status: Any) -> str: return raw +def _model_test_error(conn: Connection, model_id: str, exc: Exception) -> VerifyResult: + provider_name = provider_label(conn.provider) + raw = str(exc) + normalized = raw.lower() + exc_name = exc.__class__.__name__.lower() + status_code = getattr(exc, "status_code", None) + + logger.info( + "Model test failed for provider=%s model=%s: %s", + conn.provider, + model_id, + raw, + ) + + if status_code in (401, 403) or "authentication" in exc_name or "401" in normalized: + return VerifyResult( + "AUTH_FAILED", + False, + f"Authentication failed. Check your {provider_name} credentials and try again.", + ) + + if status_code == 404 or "notfound" in exc_name or "not found" in normalized: + if conn.provider == "azure": + message = ( + "Azure OpenAI deployment was not found. Check the deployment name, " + "API version, and endpoint." + ) + else: + message = f"Model '{model_id}' was not found on {provider_name}." + return VerifyResult("NOT_FOUND", False, message) + + if status_code == 429 or "ratelimit" in exc_name or "rate limit" in normalized: + return VerifyResult( + "RATE_LIMITED", + False, + f"{provider_name} rate limited the model test. Try again later.", + ) + + if "timeout" in exc_name or "timed out" in normalized: + return VerifyResult( + "TIMEOUT", + False, + f"{provider_name} did not respond in time. Check the endpoint and try again.", + ) + + if "connection" in exc_name or "connect" in normalized: + return VerifyResult( + "UNREACHABLE", + False, + _docker_hint( + _base_url_or_default(conn), + f"Could not reach {provider_name}. Check the endpoint and try again.", + ), + ) + + return VerifyResult( + "UNREACHABLE", + False, + f"Could not test model '{model_id}' on {provider_name}. Check the credentials, endpoint, and model name.", + ) + + async def verify_connection(conn: Connection) -> VerifyResult: spec = spec_for(conn.provider) base_url = _base_url_or_default(conn) @@ -321,15 +383,24 @@ async def _discover_bedrock_models(conn: Connection) -> list[dict[str, Any]]: return [] def list_models() -> list[dict[str, Any]]: + import os + import boto3 - client_kwargs: dict[str, str] = {"region_name": region_name} - if params.get("aws_access_key_id"): - client_kwargs["aws_access_key_id"] = params["aws_access_key_id"] - if params.get("aws_secret_access_key"): - client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"] + if bearer_token := params.get("aws_bearer_token_bedrock"): + try: + os.environ["AWS_BEARER_TOKEN_BEDROCK"] = bearer_token + client = boto3.client("bedrock", region_name=region_name) + finally: + os.environ.pop("AWS_BEARER_TOKEN_BEDROCK", None) + else: + client_kwargs: dict[str, str] = {"region_name": region_name} + if params.get("aws_access_key_id"): + client_kwargs["aws_access_key_id"] = params["aws_access_key_id"] + if params.get("aws_secret_access_key"): + client_kwargs["aws_secret_access_key"] = params["aws_secret_access_key"] + client = boto3.client("bedrock", **client_kwargs) - client = boto3.client("bedrock", **client_kwargs) response = client.list_foundation_models() results: list[dict[str, Any]] = [] for item in response.get("modelSummaries", []): @@ -393,7 +464,7 @@ async def test_model(conn: Connection, model: Model) -> VerifyResult: **kwargs, ) except Exception as exc: - return VerifyResult("UNREACHABLE", False, str(exc)) + return _model_test_error(conn, model.model_id, exc) model.supports_chat = True return VerifyResult("OK", True, "Model test succeeded.") diff --git a/surfsense_backend/app/services/model_resolver.py b/surfsense_backend/app/services/model_resolver.py index ae6fd2877..599762824 100644 --- a/surfsense_backend/app/services/model_resolver.py +++ b/surfsense_backend/app/services/model_resolver.py @@ -55,6 +55,8 @@ def to_litellm( kwargs["api_version"] = api_version kwargs.update(extra.get("litellm_params", {})) kwargs.update(extra.get("kwargs", {})) + if provider == "bedrock" and (bearer_token := kwargs.pop("aws_bearer_token_bedrock", None)): + kwargs["api_key"] = bearer_token return model_string, kwargs diff --git a/surfsense_backend/app/services/provider_registry.py b/surfsense_backend/app/services/provider_registry.py index 98bfb63c1..2a58a3468 100644 --- a/surfsense_backend/app/services/provider_registry.py +++ b/surfsense_backend/app/services/provider_registry.py @@ -38,21 +38,24 @@ class ProviderSpec: default_base_url: str | None base_url_required: bool auth_style: AuthStyle + display_name: str | None = None REGISTRY: dict[str, ProviderSpec] = { "openai": ProviderSpec( - Transport.NATIVE, "openai", "openai_models", None, False, "bearer" + Transport.NATIVE, "openai", "openai_models", None, False, "bearer", "OpenAI" ), "anthropic": ProviderSpec( - Transport.NATIVE, "anthropic", "anthropic_models", None, False, "x-api-key" + Transport.NATIVE, "anthropic", "anthropic_models", None, False, "x-api-key", "Anthropic" + ), + "azure": ProviderSpec( + Transport.NATIVE, "azure", "static", None, True, "native", "Azure OpenAI" ), - "azure": ProviderSpec(Transport.NATIVE, "azure", "static", None, True, "native"), "vertex_ai": ProviderSpec( - Transport.NATIVE, "vertex_ai", "static", None, False, "native" + Transport.NATIVE, "vertex_ai", "static", None, False, "native", "Vertex AI" ), "bedrock": ProviderSpec( - Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native" + Transport.NATIVE, "bedrock", "bedrock_models", None, False, "native", "Amazon Bedrock" ), "openrouter": ProviderSpec( Transport.OPENAI_COMPATIBLE, @@ -61,6 +64,7 @@ REGISTRY: dict[str, ProviderSpec] = { "https://openrouter.ai/api/v1", False, "bearer", + "OpenRouter", ), "openai_compatible": ProviderSpec( Transport.OPENAI_COMPATIBLE, @@ -69,6 +73,7 @@ REGISTRY: dict[str, ProviderSpec] = { None, True, "bearer", + "OpenAI-compatible provider", ), "lm_studio": ProviderSpec( Transport.OPENAI_COMPATIBLE, @@ -77,6 +82,7 @@ REGISTRY: dict[str, ProviderSpec] = { "http://localhost:1234/v1", True, "bearer", + "LM Studio", ), "ollama_chat": ProviderSpec( Transport.OLLAMA, @@ -85,6 +91,7 @@ REGISTRY: dict[str, ProviderSpec] = { "http://localhost:11434", True, "none", + "Ollama", ), } @@ -96,4 +103,12 @@ def spec_for(provider: str | None) -> ProviderSpec: ) -__all__ = ["REGISTRY", "ProviderSpec", "Transport", "spec_for"] +def provider_label(provider: str | None) -> str: + provider_key = (provider or "").strip() + spec = spec_for(provider_key) + if spec.display_name: + return spec.display_name + return provider_key.replace("_", " ").title() if provider_key else "Provider" + + +__all__ = ["REGISTRY", "ProviderSpec", "Transport", "provider_label", "spec_for"] diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts index ea91c6483..00f8fa9ad 100644 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -7,6 +7,7 @@ import type { ModelPreviewRead, ModelRead, ModelRoles, + ModelTestPreviewRequest, ModelsBulkUpdateRequest, ModelUpdateRequest, VerifyConnectionResponse, @@ -114,6 +115,20 @@ export const previewConnectionModelsMutationAtom = atomWithMutation(() => { }; }); +export const testPreviewModelMutationAtom = atomWithMutation(() => { + return { + mutationKey: ["model-connections", "test-preview"], + mutationFn: (request: ModelTestPreviewRequest) => + modelConnectionsApiService.testPreviewModel(request), + onSuccess: (result: VerifyConnectionResponse) => { + if (!result.ok) { + toast.error(result.message || "Model test failed"); + } + }, + onError: (error: Error) => toast.error(error.message || "Failed to test model"), + }; +}); + export const addManualModelMutationAtom = atomWithMutation((get) => { const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); return { diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 6c3d1a411..cf00ac6c9 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -1,12 +1,14 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { CheckCircle2, Trash2, XCircle } from "lucide-react"; +import { Trash2 } from "lucide-react"; import { useState } from "react"; +import { toast } from "sonner"; import { createModelConnectionMutationAtom, deleteModelConnectionMutationAtom, previewConnectionModelsMutationAtom, + testPreviewModelMutationAtom, updateModelRolesMutationAtom, } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { @@ -53,24 +55,6 @@ import { providerIcon, } from "./model-connections/provider-metadata"; -function StatusBadge({ connection }: { connection: ConnectionRead }) { - if (connection.last_status === "OK") { - return ( - - Healthy - - ); - } - if (connection.last_status) { - return ( - - {connection.last_status} - - ); - } - return Not tested; -} - function flattenModels(connections: ConnectionRead[]) { return connections.flatMap((connection) => connection.models.map((model) => ({ @@ -110,7 +94,6 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) {
- @@ -156,6 +139,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num const [{ data: roles }] = useAtom(modelRolesAtom); const createConnection = useAtomValue(createModelConnectionMutationAtom); const previewModels = useAtomValue(previewConnectionModelsMutationAtom); + const testPreviewModel = useAtomValue(testPreviewModelMutationAtom); const updateRoles = useAtomValue(updateModelRolesMutationAtom); const [isAddProviderOpen, setIsAddProviderOpen] = useState(false); @@ -220,9 +204,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num }); } - // Each provider connect form builds its own credential payload; the backend - // resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM. - function handleCreate(draft: ConnectionDraft) { + function connectionModelsForDraft(draft: ConnectionDraft) { const models = [...connectModels]; if (draft.seedModelId && !models.some((model) => model.model_id === draft.seedModelId)) { models.push({ @@ -233,22 +215,46 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num metadata: {}, }); } + return models; + } - createConnection.mutate( + function representativeTestModel(models: ModelSelection[]) { + const enabledModels = models.filter((model) => model.enabled); + return enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0]; + } + + // Each provider connect form builds its own credential payload; the backend + // resolver (`to_litellm`) forwards `extra.litellm_params` straight to LiteLLM. + function handleCreate(draft: ConnectionDraft) { + const models = connectionModelsForDraft(draft); + const testModel = representativeTestModel(models); + if (!testModel) { + toast.error("Select at least one model before connecting"); + return; + } + + const request = { + provider, + base_url: draft.base_url, + api_key: draft.api_key, + scope: "SEARCH_SPACE" as const, + search_space_id: searchSpaceId, + extra: draft.extra, + enabled: true, + models, + }; + + testPreviewModel.mutate( + { ...request, model_id: testModel.model_id }, { - provider, - base_url: draft.base_url, - api_key: draft.api_key, - scope: "SEARCH_SPACE", - search_space_id: searchSpaceId, - extra: draft.extra, - enabled: true, - models, - }, - { - onSuccess: () => { - setIsAddProviderOpen(false); - resetConnectState(); + onSuccess: (result) => { + if (!result.ok) return; + createConnection.mutate(request, { + onSuccess: () => { + setIsAddProviderOpen(false); + resetConnectState(); + }, + }); }, } ); @@ -380,7 +386,7 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num onOpenChange={handleConnectOpenChange} provider={provider} selectedProvider={selectedProvider} - isPending={createConnection.isPending} + isPending={createConnection.isPending || testPreviewModel.isPending} onSubmit={handleCreate} previewModels={connectModels} isPreviewingModels={previewModels.isPending} diff --git a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx index d0f8e6c16..badddb8d7 100644 --- a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx +++ b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx @@ -5,9 +5,9 @@ import { addManualModelMutationAtom, bulkUpdateModelsMutationAtom, discoverConnectionModelsMutationAtom, + testPreviewModelMutationAtom, updateModelConnectionMutationAtom, updateModelMutationAtom, - verifyModelConnectionMutationAtom, } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { Button } from "@/components/ui/button"; import { @@ -26,7 +26,7 @@ import type { ConnectionRead, ConnectionUpdateRequest, } from "@/contracts/types/model-connections.types"; -import type { SelectableModel } from "./model-utils"; +import { capability, type SelectableModel } from "./model-utils"; import { ModelsSelectionPanel } from "./models-selection-panel"; import { providerIcon } from "./provider-metadata"; @@ -39,8 +39,8 @@ export function ConnectionSettingsDialog({ connection, providerLabel, }: ConnectionSettingsDialogProps) { - const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom); const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); + const testPreviewModel = useAtomValue(testPreviewModelMutationAtom); const updateConnection = useAtomValue(updateModelConnectionMutationAtom); const addManualModel = useAtomValue(addManualModelMutationAtom); const updateModel = useAtomValue(updateModelMutationAtom); @@ -81,11 +81,45 @@ export function ConnectionSettingsDialog({ if (apiKeyDraft.trim() !== (connection.api_key ?? "")) { data.api_key = apiKeyDraft.trim() || null; } + const apiKeyForTest = Object.hasOwn(data, "api_key") + ? (data.api_key ?? null) + : (connection.api_key ?? null); - updateConnection.mutate( - { id: connection.id, data }, + const enabledModels = connection.models.filter((model) => model.enabled); + const testModel = + enabledModels.find((model) => capability(model, "chat")) ?? enabledModels[0]; + if (!testModel) { + updateConnection.mutate( + { id: connection.id, data }, + { + onSuccess: () => setApiKeyDraft(""), + } + ); + return; + } + + testPreviewModel.mutate( { - onSuccess: () => setApiKeyDraft(""), + provider: connection.provider, + base_url: data.base_url, + api_key: apiKeyForTest, + scope: "SEARCH_SPACE", + search_space_id: connection.search_space_id, + extra: connection.extra ?? {}, + enabled: connection.enabled, + models: [], + model_id: testModel.model_id, + }, + { + onSuccess: (result) => { + if (!result.ok) return; + updateConnection.mutate( + { id: connection.id, data }, + { + onSuccess: () => setApiKeyDraft(""), + } + ); + }, } ); } @@ -219,26 +253,15 @@ export function ConnectionSettingsDialog({ onBulkToggle={handleBulkToggle} /> - {connection.last_status && connection.last_status !== "OK" ? ( -

- {connection.last_error || "Could not list models."} Chat may still work; add model - IDs manually if discovery is unavailable. -

- ) : null}
- diff --git a/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx b/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx index 315b9d3fa..51263d5f5 100644 --- a/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx +++ b/surfsense_web/components/settings/model-connections/provider-connect-dialog.tsx @@ -94,6 +94,8 @@ export function ProviderConnectDialog({ })(); const canRefreshModels = !isAzure && !isVertex && (!isBedrock || canSubmit); + const hasEnabledModel = previewModels.some((model) => model.enabled) || Boolean(currentDraft.seedModelId); + const canConnect = canSubmit && hasEnabledModel; return ( @@ -134,7 +136,7 @@ export function ProviderConnectDialog({ onOpenChange(false)} onSubmit={() => onSubmit(currentDraft)} - canSubmit={canSubmit} + canSubmit={canConnect} isPending={isPending} /> diff --git a/surfsense_web/contracts/types/model-connections.types.ts b/surfsense_web/contracts/types/model-connections.types.ts index 134c740b2..16db93868 100644 --- a/surfsense_web/contracts/types/model-connections.types.ts +++ b/surfsense_web/contracts/types/model-connections.types.ts @@ -66,6 +66,10 @@ export const connectionCreateRequest = z.object({ models: z.array(modelSelection).default([]), }); +export const modelTestPreviewRequest = connectionCreateRequest.extend({ + model_id: z.string().min(1), +}); + export const connectionUpdateRequest = z.object({ provider: z.string().nullable().optional(), base_url: z.string().nullable().optional(), @@ -129,6 +133,7 @@ export type ModelPreviewRead = z.infer; export type ModelSelection = z.infer; export type ConnectionRead = z.infer; export type ConnectionCreateRequest = z.infer; +export type ModelTestPreviewRequest = z.infer; export type ConnectionUpdateRequest = z.infer; export type ModelCreateRequest = z.infer; export type ModelUpdateRequest = z.infer; diff --git a/surfsense_web/lib/apis/model-connections-api.service.ts b/surfsense_web/lib/apis/model-connections-api.service.ts index f463a27e7..d875255ad 100644 --- a/surfsense_web/lib/apis/model-connections-api.service.ts +++ b/surfsense_web/lib/apis/model-connections-api.service.ts @@ -11,6 +11,7 @@ import { type ModelProviderRead, type ModelRead, type ModelRoles, + type ModelTestPreviewRequest, type ModelsBulkUpdateRequest, type ModelUpdateRequest, modelCreateRequest, @@ -19,6 +20,7 @@ import { modelProviderListResponse, modelRead, modelRoles, + modelTestPreviewRequest, modelsBulkUpdateRequest, modelUpdateRequest, type VerifyConnectionResponse, @@ -92,6 +94,20 @@ class ModelConnectionsApiService { ); }; + testPreviewModel = async (request: ModelTestPreviewRequest): Promise => { + const parsed = modelTestPreviewRequest.safeParse(request); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); + } + return baseApiService.post( + `/api/v1/model-connections/test-preview`, + verifyConnectionResponse, + { + body: parsed.data, + } + ); + }; + addManualModel = async ( connectionId: number, request: ModelCreateRequest From e77b0c5d4eaf0fd83493fbe273d619aa23e7f79e Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 00:25:53 +0530 Subject: [PATCH 34/59] feat(icons): add Azure, Claude, and LM Studio icons; update Bedrock icon and provider metadata --- .../components/icons/providers/azure.svg | 1 + .../components/icons/providers/bedrock.svg | 2 +- .../components/icons/providers/claude.svg | 1 + .../components/icons/providers/index.ts | 3 +++ .../components/icons/providers/lm-studio.svg | 21 +++++++++++++++++++ .../components/icons/providers/vertexai.svg | 2 +- .../model-connections/provider-metadata.tsx | 6 +++--- surfsense_web/lib/provider-icons.tsx | 14 +++++++++---- 8 files changed, 41 insertions(+), 9 deletions(-) create mode 100644 surfsense_web/components/icons/providers/azure.svg create mode 100644 surfsense_web/components/icons/providers/claude.svg create mode 100644 surfsense_web/components/icons/providers/lm-studio.svg diff --git a/surfsense_web/components/icons/providers/azure.svg b/surfsense_web/components/icons/providers/azure.svg new file mode 100644 index 000000000..ba80f55ca --- /dev/null +++ b/surfsense_web/components/icons/providers/azure.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/surfsense_web/components/icons/providers/bedrock.svg b/surfsense_web/components/icons/providers/bedrock.svg index 195aa6594..cde500c0d 100644 --- a/surfsense_web/components/icons/providers/bedrock.svg +++ b/surfsense_web/components/icons/providers/bedrock.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/surfsense_web/components/icons/providers/claude.svg b/surfsense_web/components/icons/providers/claude.svg new file mode 100644 index 000000000..8d732d5b0 --- /dev/null +++ b/surfsense_web/components/icons/providers/claude.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/surfsense_web/components/icons/providers/index.ts b/surfsense_web/components/icons/providers/index.ts index aefa2a053..8275595b1 100644 --- a/surfsense_web/components/icons/providers/index.ts +++ b/surfsense_web/components/icons/providers/index.ts @@ -1,8 +1,10 @@ export { default as Ai21Icon } from "./ai21.svg"; export { default as AnthropicIcon } from "./anthropic.svg"; export { default as AnyscaleIcon } from "./anyscale.svg"; +export { default as AzureIcon } from "./azure.svg"; export { default as BedrockIcon } from "./bedrock.svg"; export { default as CerebrasIcon } from "./cerebras.svg"; +export { default as ClaudeIcon } from "./claude.svg"; export { default as CohereIcon } from "./cohere.svg"; export { default as CometApiIcon } from "./cometapi.svg"; export { default as DatabricksIcon } from "./dbrx.svg"; @@ -15,6 +17,7 @@ export { default as GroqIcon } from "./groq.svg"; export { default as HuggingFaceIcon } from "./huggingface.svg"; export { default as MiniMaxIcon } from "./minimax.svg"; export { default as MistralIcon } from "./mistral.svg"; +export { default as LmStudioIcon } from "./lm-studio.svg"; export { default as MoonshotIcon } from "./moonshot.svg"; export { default as NscaleIcon } from "./nscale.svg"; export { default as OllamaIcon } from "./ollama.svg"; diff --git a/surfsense_web/components/icons/providers/lm-studio.svg b/surfsense_web/components/icons/providers/lm-studio.svg new file mode 100644 index 000000000..b6ae7db3e --- /dev/null +++ b/surfsense_web/components/icons/providers/lm-studio.svg @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/surfsense_web/components/icons/providers/vertexai.svg b/surfsense_web/components/icons/providers/vertexai.svg index 45adce83b..e46a3ca0f 100644 --- a/surfsense_web/components/icons/providers/vertexai.svg +++ b/surfsense_web/components/icons/providers/vertexai.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/surfsense_web/components/settings/model-connections/provider-metadata.tsx b/surfsense_web/components/settings/model-connections/provider-metadata.tsx index 73e873393..8b8a877b9 100644 --- a/surfsense_web/components/settings/model-connections/provider-metadata.tsx +++ b/surfsense_web/components/settings/model-connections/provider-metadata.tsx @@ -19,12 +19,12 @@ export const PROVIDER_DISPLAY: Record< anthropic: { name: "Claude", subtitle: "Anthropic", - iconKey: "anthropic", + iconKey: "claude", defaultBaseUrl: "https://api.anthropic.com/v1", }, - azure: { name: "Azure OpenAI", subtitle: "Microsoft Azure", iconKey: "azure_openai" }, + azure: { name: "Azure OpenAI", subtitle: "Microsoft Azure", iconKey: "azure" }, bedrock: { name: "Amazon Bedrock", subtitle: "AWS", iconKey: "bedrock" }, - lm_studio: { name: "LM Studio", subtitle: "LM Studio", iconKey: "custom" }, + lm_studio: { name: "LM Studio", subtitle: "LM Studio", iconKey: "lm_studio" }, ollama_chat: { name: "Ollama", subtitle: "Ollama", iconKey: "ollama" }, openai: { name: "GPT", diff --git a/surfsense_web/lib/provider-icons.tsx b/surfsense_web/lib/provider-icons.tsx index 3bb310904..4b2a4dfbe 100644 --- a/surfsense_web/lib/provider-icons.tsx +++ b/surfsense_web/lib/provider-icons.tsx @@ -1,10 +1,11 @@ import { Cpu, Shuffle } from "lucide-react"; import { Ai21Icon, - AnthropicIcon, AnyscaleIcon, + AzureIcon, BedrockIcon, CerebrasIcon, + ClaudeIcon, CloudflareIcon, CohereIcon, CometApiIcon, @@ -16,6 +17,7 @@ import { GitHubModelsIcon, GroqIcon, HuggingFaceIcon, + LmStudioIcon, MiniMaxIcon, MistralIcon, MoonshotIcon, @@ -54,12 +56,13 @@ export function getProviderIcon( case "ALIBABA_QWEN": return ; case "ANTHROPIC": - return ; + case "CLAUDE": + return ; case "ANYSCALE": return ; case "AZURE": case "AZURE_OPENAI": - return ; + return ; case "AWS_BEDROCK": case "BEDROCK": return ; @@ -72,7 +75,7 @@ export function getProviderIcon( case "COMETAPI": return ; case "CUSTOM": - return ; + return ; case "DATABRICKS": return ; case "DEEPINFRA": @@ -89,6 +92,8 @@ export function getProviderIcon( return ; case "HUGGINGFACE": return ; + case "LM_STUDIO": + return ; case "MINIMAX": return ; case "MISTRAL": @@ -98,6 +103,7 @@ export function getProviderIcon( case "NSCALE": return ; case "OLLAMA": + case "OLLAMA_CHAT": return ; case "OPENAI": return ; From 7a1bb2acd6fdce526aebe41d99a916ccad75c689 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 00:46:53 +0530 Subject: [PATCH 35/59] feat(model-connections): refactor model roles UI --- .../settings/model-connections-settings.tsx | 132 +++++++++--------- .../connection-settings-dialog.tsx | 7 +- 2 files changed, 75 insertions(+), 64 deletions(-) diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index cf00ac6c9..c61368974 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -1,7 +1,7 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { Trash2 } from "lucide-react"; +import { Dot, Trash2 } from "lucide-react"; import { useState } from "react"; import { toast } from "sonner"; import { @@ -30,7 +30,6 @@ import { } from "@/components/ui/alert-dialog"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; -import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; import { Label } from "@/components/ui/label"; import { Select, @@ -78,7 +77,7 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { return (
-
+
{providerIcon(connection.provider)} @@ -100,10 +99,11 @@ function ConnectionCard({ connection }: { connection: ConnectionRead }) { @@ -335,7 +335,11 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num {providerIcon(model.provider)} - {modelLabel(model)} · {model.connectionName} + + {modelLabel(model)} + ); @@ -343,7 +347,66 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num return (
-
+
+
+

Model Roles

+

+ Pick which enabled model powers chat, vision, and image generation for this search + space. +

+
+
+
+ + +
+
+ + +
+
+ + +
+
+
+ + + +

Add Provider

@@ -408,63 +471,6 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
) : null}
- - - - Model Roles - - Pick which enabled model powers chat, vision, and image generation for this search - space. - - - -
- - -
-
- - -
-
- - -
-
-
); } diff --git a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx index badddb8d7..d20dbbdc6 100644 --- a/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx +++ b/surfsense_web/components/settings/model-connections/connection-settings-dialog.tsx @@ -157,7 +157,12 @@ export function ConnectionSettingsDialog({ return ( - From 45d27ba87992b6ec9c07a0df3d24357e7f830d2e Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 01:37:12 +0530 Subject: [PATCH 36/59] feat(model-connections): enhance auto mode with auto pinning --- .../versions/160_add_model_connections.py | 39 ++++- surfsense_backend/app/db.py | 6 +- .../app/routes/model_connections_routes.py | 140 +++++++++++++++++- .../settings/model-connections-settings.tsx | 30 +++- 4 files changed, 196 insertions(+), 19 deletions(-) diff --git a/surfsense_backend/alembic/versions/160_add_model_connections.py b/surfsense_backend/alembic/versions/160_add_model_connections.py index 49d6315ca..2c35bd568 100644 --- a/surfsense_backend/alembic/versions/160_add_model_connections.py +++ b/surfsense_backend/alembic/versions/160_add_model_connections.py @@ -61,9 +61,21 @@ def _create_index_if_missing( op.create_index(index_name, table_name, columns, unique=False) -def _add_searchspace_column_if_missing(column_name: str) -> None: +def _add_searchspace_column_if_missing( + column_name: str, + *, + server_default: object | None = None, +) -> None: if not _column_exists("searchspaces", column_name): - op.add_column("searchspaces", sa.Column(column_name, sa.Integer(), nullable=True)) + op.add_column( + "searchspaces", + sa.Column( + column_name, + sa.Integer(), + nullable=True, + server_default=server_default, + ), + ) def _drop_column_if_exists(table_name: str, column_name: str) -> None: @@ -233,9 +245,26 @@ def upgrade() -> None: _create_index_if_missing("ix_models_model_id", "models", ["model_id"]) _create_index_if_missing("ix_models_billing_tier", "models", ["billing_tier"]) - _add_searchspace_column_if_missing("chat_model_id") - _add_searchspace_column_if_missing("image_gen_model_id") - _add_searchspace_column_if_missing("vision_model_id") + _add_searchspace_column_if_missing("chat_model_id", server_default=sa.text("0")) + _add_searchspace_column_if_missing("image_gen_model_id", server_default=sa.text("0")) + _add_searchspace_column_if_missing("vision_model_id", server_default=sa.text("0")) + for column_name in ("chat_model_id", "image_gen_model_id", "vision_model_id"): + op.alter_column( + "searchspaces", + column_name, + existing_type=sa.Integer(), + existing_nullable=True, + server_default=sa.text("0"), + ) + op.execute( + """ + UPDATE searchspaces + SET + chat_model_id = COALESCE(chat_model_id, 0), + image_gen_model_id = COALESCE(image_gen_model_id, 0), + vision_model_id = COALESCE(vision_model_id, 0) + """ + ) op.execute("DROP TYPE IF EXISTS connectionprotocol") diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index eeed9932d..728031fa0 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -1853,13 +1853,13 @@ class SearchSpace(BaseModel, TimestampMixin): # - Negative IDs: Global virtual models from global_llm_config.yaml # - Positive IDs: User/search-space models from the models table chat_model_id = Column( - Integer, nullable=True, default=0 + Integer, nullable=True, default=0, server_default="0" ) # For agent/chat operations, defaults to Auto mode image_gen_model_id = Column( - Integer, nullable=True, default=0 + Integer, nullable=True, default=0, server_default="0" ) # For image generation, defaults to Auto mode when eligible vision_model_id = Column( - Integer, nullable=True, default=0 + Integer, nullable=True, default=0, server_default="0" ) # For vision/screenshot analysis, defaults to Auto mode ai_file_sort_enabled = Column( diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 76e4a3dfb..6db9aa9f3 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -131,6 +131,95 @@ def _default_model_for(models: list[Model], capability: str) -> int | None: return None +async def _load_role_model( + session: AsyncSession, + search_space_id: int, + model_id: int, +) -> Model | dict | None: + if model_id < 0: + return next( + (model for model in config.GLOBAL_MODELS if model.get("id") == model_id), + None, + ) + + result = await session.execute( + select(Model) + .options(selectinload(Model.connection)) + .where(Model.id == model_id) + ) + model = result.scalars().first() + if model is None or model.connection.search_space_id != search_space_id: + return None + return model + + +def _role_model_enabled(model: Model | dict) -> bool: + if isinstance(model, dict): + return bool(model.get("enabled", True)) + return bool(model.enabled and model.connection.enabled) + + +async def _validate_role_model_id( + session: AsyncSession, + *, + search_space_id: int, + model_id: int | None, + capability: str, +) -> int: + if model_id is None or model_id == 0: + return 0 + + model = await _load_role_model(session, search_space_id, model_id) + if model and _role_model_enabled(model) and has_capability(model, capability): + return model_id + + raise HTTPException( + status_code=400, + detail=f"Selected model is not available for {capability}", + ) + + +async def _resolve_role_model_id( + session: AsyncSession, + *, + search_space_id: int, + model_id: int | None, + capability: str, +) -> int: + try: + return await _validate_role_model_id( + session, + search_space_id=search_space_id, + model_id=model_id, + capability=capability, + ) + except HTTPException: + return 0 + + +async def _clear_invalid_roles(session: AsyncSession, search_space_id: int) -> SearchSpace: + search_space = await _get_search_space(session, search_space_id) + search_space.chat_model_id = await _resolve_role_model_id( + session, + search_space_id=search_space_id, + model_id=search_space.chat_model_id, + capability="chat", + ) + search_space.vision_model_id = await _resolve_role_model_id( + session, + search_space_id=search_space_id, + model_id=search_space.vision_model_id, + capability="vision", + ) + search_space.image_gen_model_id = await _resolve_role_model_id( + session, + search_space_id=search_space_id, + model_id=search_space.image_gen_model_id, + capability="image_gen", + ) + return search_space + + async def _default_unset_roles( session: AsyncSession, conn: Connection, @@ -372,9 +461,13 @@ async def update_connection( ): conn = await _load_connection(session, connection_id) await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value) + search_space_id = conn.search_space_id for key, value in data.model_dump(exclude_unset=True).items(): setattr(conn, key, value) await session.commit() + if search_space_id is not None: + await _clear_invalid_roles(session, search_space_id) + await session.commit() conn = await _load_connection(session, connection_id) return _connection_read(conn, list(conn.models)) @@ -387,8 +480,12 @@ async def delete_connection( ): conn = await _load_connection(session, connection_id) await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_DELETE.value) + search_space_id = conn.search_space_id await session.delete(conn) await session.commit() + if search_space_id is not None: + await _clear_invalid_roles(session, search_space_id) + await session.commit() return {"status": "deleted"} @@ -439,6 +536,8 @@ async def discover_connection_models( await session.commit() conn = await _load_connection(session, connection_id) await _default_unset_roles(session, conn, list(conn.models)) + if conn.search_space_id is not None: + await _clear_invalid_roles(session, conn.search_space_id) await session.commit() conn = await _load_connection(session, connection_id) return [_model_read(model) for model in conn.models] @@ -476,7 +575,10 @@ async def add_manual_model( await session.refresh(model) conn = await _load_connection(session, connection_id) await _default_unset_roles(session, conn, list(conn.models)) + if conn.search_space_id is not None: + await _clear_invalid_roles(session, conn.search_space_id) await session.commit() + await session.refresh(model) return _model_read(model) @@ -489,6 +591,7 @@ async def bulk_update_models( ): conn = await _load_connection(session, connection_id) await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value) + search_space_id = conn.search_space_id model_ids = set(data.model_ids) await session.execute( @@ -498,6 +601,10 @@ async def bulk_update_models( ) await session.commit() session.expire_all() + if search_space_id is not None: + await _clear_invalid_roles(session, search_space_id) + await session.commit() + session.expire_all() result = await session.execute( select(Model) @@ -521,11 +628,16 @@ async def update_model( if not model: raise HTTPException(status_code=404, detail="Model not found") await _assert_connection_access(session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value) + search_space_id = model.connection.search_space_id update = data.model_dump(exclude_unset=True) for key, value in update.items(): setattr(model, key, value) await session.commit() await session.refresh(model) + if search_space_id is not None: + await _clear_invalid_roles(session, search_space_id) + await session.commit() + await session.refresh(model) return _model_read(model) @@ -560,7 +672,9 @@ async def get_model_roles( Permission.LLM_CONFIGS_CREATE.value, "You don't have permission to view model roles in this search space", ) - search_space = await _get_search_space(session, search_space_id) + search_space = await _clear_invalid_roles(session, search_space_id) + await session.commit() + await session.refresh(search_space) return ModelRolesRead( chat_model_id=search_space.chat_model_id, vision_model_id=search_space.vision_model_id, @@ -583,8 +697,28 @@ async def update_model_roles( "You don't have permission to update model roles in this search space", ) search_space = await _get_search_space(session, search_space_id) - for key, value in data.model_dump(exclude_unset=True).items(): - setattr(search_space, key, value) + updates = data.model_dump(exclude_unset=True) + if "chat_model_id" in updates: + search_space.chat_model_id = await _validate_role_model_id( + session, + search_space_id=search_space_id, + model_id=updates["chat_model_id"], + capability="chat", + ) + if "vision_model_id" in updates: + search_space.vision_model_id = await _validate_role_model_id( + session, + search_space_id=search_space_id, + model_id=updates["vision_model_id"], + capability="vision", + ) + if "image_gen_model_id" in updates: + search_space.image_gen_model_id = await _validate_role_model_id( + session, + search_space_id=search_space_id, + model_id=updates["image_gen_model_id"], + capability="image_gen", + ) await session.commit() await session.refresh(search_space) return ModelRolesRead( diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index c61368974..1f5c166b3 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -65,6 +65,11 @@ function flattenModels(connections: ConnectionRead[]) { ); } +function roleSelectValue(modelId: number | null | undefined, models: Array<{ id: number }>) { + if (!modelId) return "0"; + return models.some((model) => model.id === modelId) ? String(modelId) : "0"; +} + function ConnectionCard({ connection }: { connection: ConnectionRead }) { const deleteConnection = useAtomValue(deleteModelConnectionMutationAtom); @@ -349,8 +354,8 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
-

Model Roles

-

+

Model Roles

+

Pick which enabled model powers chat, vision, and image generation for this search space.

@@ -358,8 +363,12 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
+

+ Primary model for chat responses and agent tasks. You can also change it from the + chat. +

updateRoles.mutate({ vision_model_id: Number(value) })} > @@ -388,8 +401,9 @@ export function ModelConnectionsSettings({ searchSpaceId }: { searchSpaceId: num
+

Used when generating images in chat.

setSearch(event.target.value)} - placeholder="Search chat models..." - className="pl-9" + placeholder="Search chat models" + className="h-8 border-0 bg-transparent pl-6 text-sm shadow-none" />
-
+
@@ -217,17 +226,22 @@ export function ModelSelector({ type="button" variant="ghost" size="sm" - className={cn("h-8 gap-2 rounded-full px-3 text-muted-foreground", className)} + className={cn( + "h-8 min-w-0 gap-2 rounded-md px-3 text-muted-foreground transition-colors", + "hover:bg-foreground/10 hover:text-foreground", + "data-[state=open]:bg-foreground/10 data-[state=open]:text-foreground", + className + )} > {selected ? ( - getProviderIcon(selected.provider, { className: "size-4" }) + getProviderIcon(selected.provider, { className: "size-4 shrink-0" }) ) : ( - + )} - + {selected ? modelName(selected) : "Auto"} - + ); @@ -235,7 +249,8 @@ export function ModelSelector({ return ( {trigger} - + + Select Chat Model @@ -248,7 +263,7 @@ export function ModelSelector({ return ( {trigger} - + {content} From 7493ba93241a2d526f3bd8e42954b01c3879c189 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 03:34:13 +0530 Subject: [PATCH 41/59] refactor(icons): replace Clock3 icon with AlarmClock for automations --- .../automations/components/automations-empty-state.tsx | 4 ++-- .../automations/components/automations-table.tsx | 4 ++-- .../components/layout/providers/LayoutDataProvider.tsx | 4 ++-- surfsense_web/components/new-chat/chat-example-prompts.tsx | 4 ++-- .../components/tool-ui/automation/create-automation.tsx | 6 +++--- surfsense_web/contracts/enums/toolIcons.tsx | 4 ++-- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx index 70d9990f8..1ee71c636 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-empty-state.tsx @@ -1,5 +1,5 @@ "use client"; -import { Clock3 } from "lucide-react"; +import { AlarmClock } from "lucide-react"; import Link from "next/link"; import { Button } from "@/components/ui/button"; @@ -18,7 +18,7 @@ export function AutomationsEmptyState({ searchSpaceId, canCreate }: AutomationsE return (
- +

No automations yet

diff --git a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx index 727636b43..9ca510d44 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/automations/components/automations-table.tsx @@ -1,5 +1,5 @@ "use client"; -import { CalendarDays, Clock3, Info } from "lucide-react"; +import { CalendarDays, AlarmClock, Info } from "lucide-react"; import { Table, TableBody, TableHead, TableHeader, TableRow } from "@/components/ui/table"; import type { AutomationSummary } from "@/contracts/types/automation.types"; import { AutomationRow } from "./automation-row"; @@ -31,7 +31,7 @@ export function AutomationsTable({ - + Name diff --git a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx index 5c62f6a7d..d2754594a 100644 --- a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx +++ b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx @@ -2,7 +2,7 @@ import { useQuery } from "@tanstack/react-query"; import { useAtom, useAtomValue, useSetAtom } from "jotai"; -import { AlertTriangle, Clock3, Inbox, LibraryBig } from "lucide-react"; +import { AlertTriangle, AlarmClock, Inbox, LibraryBig } from "lucide-react"; import { useParams, usePathname, useRouter } from "next/navigation"; import { useTranslations } from "next-intl"; import { useTheme } from "next-themes"; @@ -342,7 +342,7 @@ export function LayoutDataProvider({ searchSpaceId, children }: LayoutDataProvid { title: "Automations", url: `/dashboard/${searchSpaceId}/automations`, - icon: Clock3, + icon: AlarmClock, isActive: isAutomationsActive, }, isMobile diff --git a/surfsense_web/components/new-chat/chat-example-prompts.tsx b/surfsense_web/components/new-chat/chat-example-prompts.tsx index 4fdc32a92..28fa79c9d 100644 --- a/surfsense_web/components/new-chat/chat-example-prompts.tsx +++ b/surfsense_web/components/new-chat/chat-example-prompts.tsx @@ -1,7 +1,7 @@ "use client"; import { - Clock3, + AlarmClock, FilePlus2, Search, Settings2, @@ -22,7 +22,7 @@ interface ChatExamplePromptsProps { const CATEGORY_ICONS: Record = { search: Search, create: FilePlus2, - automate: Clock3, + automate: AlarmClock, tools: Settings2, }; diff --git a/surfsense_web/components/tool-ui/automation/create-automation.tsx b/surfsense_web/components/tool-ui/automation/create-automation.tsx index 644ccd822..5cffdeb6c 100644 --- a/surfsense_web/components/tool-ui/automation/create-automation.tsx +++ b/surfsense_web/components/tool-ui/automation/create-automation.tsx @@ -2,7 +2,7 @@ import type { ToolCallMessagePartProps } from "@assistant-ui/react"; import { useAtomValue } from "jotai"; -import { AlertCircle, Clock3, CornerDownLeftIcon, ExternalLink, Pencil } from "lucide-react"; +import { AlertCircle, AlarmClock, CornerDownLeftIcon, ExternalLink, Pencil } from "lucide-react"; import Link from "next/link"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { @@ -211,7 +211,7 @@ function ApprovalCard({ args, interruptData, onDecision }: ApprovalCardProps) {

- +

{phase === "rejected" @@ -404,7 +404,7 @@ function SavedCard({ result }: { result: SavedResult }) { return (

- +

Automation saved

{result.name}

diff --git a/surfsense_web/contracts/enums/toolIcons.tsx b/surfsense_web/contracts/enums/toolIcons.tsx index 496c26577..98f72796d 100644 --- a/surfsense_web/contracts/enums/toolIcons.tsx +++ b/surfsense_web/contracts/enums/toolIcons.tsx @@ -1,7 +1,7 @@ import { Brain, Calendar, - Clock3, + AlarmClock, FileEdit, FilePlus, FileText, @@ -47,7 +47,7 @@ const TOOL_ICONS: Record = { scrape_webpage: ScanLine, web_search: Globe, // Automations - create_automation: Clock3, + create_automation: AlarmClock, // Memory update_memory: Brain, // Filesystem (built-in deepagent + middleware) From 02070201fb7b2b18fa326e81b947691c6ae02d0c Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 09:29:38 +0530 Subject: [PATCH 42/59] feat(model-connections): add hint support for API Base URL field and improve dialog accessibility --- .../model-connections/connect-fields.tsx | 17 ++++++++++------- .../model-connections/default-connect-form.tsx | 16 +++++++++++++++- .../provider-connect-dialog.tsx | 15 ++++++++++++--- 3 files changed, 37 insertions(+), 11 deletions(-) diff --git a/surfsense_web/components/settings/model-connections/connect-fields.tsx b/surfsense_web/components/settings/model-connections/connect-fields.tsx index 6ef7a8408..9febb8a1e 100644 --- a/surfsense_web/components/settings/model-connections/connect-fields.tsx +++ b/surfsense_web/components/settings/model-connections/connect-fields.tsx @@ -1,4 +1,5 @@ import { Eye, EyeOff } from "lucide-react"; +import type { ReactNode } from "react"; import { useState } from "react"; import { Button } from "@/components/ui/button"; import { DialogFooter } from "@/components/ui/dialog"; @@ -9,25 +10,27 @@ import { Spinner } from "@/components/ui/spinner"; interface ApiBaseUrlFieldProps { value: string; onChange: (value: string) => void; - optional?: boolean; /** Placeholder, typically the provider's prefilled default base URL. */ placeholder?: string; + hint?: ReactNode; } /** Shared API Base URL input. The prefilled default is passed in via `value`. */ -export function ApiBaseUrlField({ value, onChange, optional, placeholder }: ApiBaseUrlFieldProps) { +export function ApiBaseUrlField({ + value, + onChange, + placeholder, + hint, +}: ApiBaseUrlFieldProps) { return (
- + onChange(event.target.value)} placeholder={placeholder || "https://api.example.com/v1"} /> -

- Local URLs are tested from the backend container, so use host.docker.internal instead of - localhost. -

+ {hint ?

{hint}

: null}
); } diff --git a/surfsense_web/components/settings/model-connections/default-connect-form.tsx b/surfsense_web/components/settings/model-connections/default-connect-form.tsx index 768c0b5da..ce638f5e6 100644 --- a/surfsense_web/components/settings/model-connections/default-connect-form.tsx +++ b/surfsense_web/components/settings/model-connections/default-connect-form.tsx @@ -2,6 +2,19 @@ import { useEffect, useState } from "react"; import { ApiBaseUrlField, ApiKeyField } from "./connect-fields"; import type { ProviderConnectFormProps } from "./provider-metadata"; +function baseUrlHint(provider: string) { + if (provider === "ollama_chat" || provider === "lm_studio") { + return "For local servers, use host.docker.internal instead of localhost."; + } + if (provider === "openai_compatible") { + return "Enter the full endpoint URL."; + } + if (provider === "openai" || provider === "anthropic" || provider === "openrouter") { + return "Override only if you route through a proxy or gateway."; + } + return undefined; +} + /** * Connect form for OpenAI-compatible / native key providers (OpenAI, Anthropic, * OpenRouter, OpenAI-Compatible, LM Studio, Ollama, …). The base URL is @@ -16,6 +29,7 @@ export function DefaultConnectForm({ const [baseUrl, setBaseUrl] = useState(defaultBaseUrl); const [apiKey, setApiKey] = useState(""); const isOllama = provider === "ollama_chat"; + const hint = baseUrlHint(provider); const canSubmit = !(baseUrlRequired && !baseUrl.trim()); useEffect(() => { @@ -27,8 +41,8 @@ export function DefaultConnectForm({ (null); const [currentDraft, setCurrentDraft] = useState({ base_url: null, api_key: null, @@ -99,12 +100,20 @@ export function ProviderConnectDialog({ return ( - + { + event.preventDefault(); + titleRef.current?.focus(); + }} + >
{providerIcon(provider, "size-5")}
- Connect {meta.name} + + Connect {meta.name} + {meta.subtitle}
From 50668775f81a81e319d6cbb47f02a326bba92403 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 12:14:17 +0530 Subject: [PATCH 43/59] feat(model-selector): enhance model selection with connection scope and free model indication --- .../components/new-chat/model-selector.tsx | 100 ++++++++++-------- 1 file changed, 58 insertions(+), 42 deletions(-) diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 0f5f50849..90226dde5 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -1,8 +1,10 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { Check, ChevronDown, Cpu, ImageOff, Search, Settings2, Zap } from "lucide-react"; -import { useMemo, useState } from "react"; +import { Check, ChevronDown, Cpu, Search, Settings2, Zap } from "lucide-react"; +import { useRouter } from "next/navigation"; +import type { UIEvent } from "react"; +import { useCallback, useMemo, useState } from "react"; import { updateModelRolesMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { globalModelConnectionsAtom, @@ -23,41 +25,30 @@ import { Input } from "@/components/ui/input"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Spinner } from "@/components/ui/spinner"; import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; -import type { - GlobalImageGenConfig, - GlobalNewLLMConfig, - GlobalVisionLLMConfig, - ImageGenerationConfig, - NewLLMConfigPublic, - VisionLLMConfig, -} from "@/contracts/types/new-llm-config.types"; import { useIsMobile } from "@/hooks/use-mobile"; import { getProviderIcon } from "@/lib/provider-icons"; import { cn } from "@/lib/utils"; +import { providerDisplay } from "../settings/model-connections/provider-metadata"; interface ModelSelectorProps { - onEditLLM: (config: NewLLMConfigPublic | GlobalNewLLMConfig, isGlobal: boolean) => void; - onAddNewLLM: (provider?: string) => void; - onEditImage?: (config: ImageGenerationConfig | GlobalImageGenConfig, isGlobal: boolean) => void; - onAddNewImage?: (provider?: string) => void; - onEditVision?: (config: VisionLLMConfig | GlobalVisionLLMConfig, isGlobal: boolean) => void; - onAddNewVision?: (provider?: string) => void; + searchSpaceId: number; className?: string; } type ChatModel = ModelRead & { connectionId: number; connectionLabel: string; + connectionScope: string; provider: string; }; function modelName(model: ModelRead) { - return model.display_name || model.model_id; + return (model.display_name || model.model_id).replace(/\s+\(free\)$/i, ""); } function connectionLabel(connection: ConnectionRead) { - if (connection.scope === "GLOBAL") return "Hosted"; - return connection.provider; + if (connection.scope === "GLOBAL") return "Global"; + return providerDisplay(connection.provider).name; } function flattenChatModels(connections: ConnectionRead[]) { @@ -68,11 +59,16 @@ function flattenChatModels(connections: ConnectionRead[]) { ...model, connectionId: connection.id, connectionLabel: connectionLabel(connection), + connectionScope: connection.scope, provider: connection.provider, })) ); } +function isFreeGlobalModel(model: ChatModel) { + return model.connectionScope === "GLOBAL" && model.billing_tier?.toLowerCase() === "free"; +} + function groupedModels(models: ChatModel[]) { return models.reduce>((groups, model) => { const key = model.connectionLabel; @@ -83,23 +79,14 @@ function groupedModels(models: ChatModel[]) { } export function ModelSelector({ - onAddNewLLM, - onEditLLM, - onEditImage, - onAddNewImage, - onEditVision, - onAddNewVision, + searchSpaceId, className, }: ModelSelectorProps) { - void onEditLLM; - void onEditImage; - void onAddNewImage; - void onEditVision; - void onAddNewVision; - + const router = useRouter(); const isMobile = useIsMobile(); const [open, setOpen] = useState(false); const [search, setSearch] = useState(""); + const [scrollPos, setScrollPos] = useState<"top" | "middle" | "bottom">("top"); const [{ data: globalConnections = [], isLoading: globalLoading }] = useAtom( globalModelConnectionsAtom ); @@ -130,11 +117,18 @@ export function ModelSelector({ function manageModelConnections() { setOpen(false); - onAddNewLLM(); + router.push(`/dashboard/${searchSpaceId}/search-space-settings/models`); } + const handleScroll = useCallback((event: UIEvent) => { + const el = event.currentTarget; + const atTop = el.scrollTop <= 2; + const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; + setScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); + }, []); + const content = ( -
+
@@ -146,7 +140,14 @@ export function ModelSelector({ />
-
+
@@ -228,6 +243,7 @@ export function ModelSelector({ size="sm" className={cn( "h-8 min-w-0 gap-2 rounded-md px-3 text-muted-foreground transition-colors", + "select-none", "hover:bg-foreground/10 hover:text-foreground", "data-[state=open]:bg-foreground/10 data-[state=open]:text-foreground", className @@ -263,7 +279,7 @@ export function ModelSelector({ return ( {trigger} - + {content} From bd4a04f2e72a9aebb7a5f614f83bbd43d97487f1 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 13 Jun 2026 12:45:43 +0530 Subject: [PATCH 44/59] feat(database-migrations): add migration to remove legacy model config tables and remove stale model connection code --- ...38_add_thread_auto_model_pinning_fields.py | 2 +- .../161_remove_legacy_model_configs.py | 270 +++ .../deliverables/tools/generate_image.py | 2 +- .../app/agents/chat/runtime/llm_config.py | 132 +- .../app/automations/services/model_policy.py | 12 +- surfsense_backend/app/config/__init__.py | 50 +- .../app/config/global_llm_config.example.yaml | 80 +- surfsense_backend/app/db.py | 287 +-- .../prompts/default_system_instructions.py | 4 +- .../system_prompt_composer/composer.py | 3 +- surfsense_backend/app/routes/__init__.py | 4 - .../app/routes/image_generation_routes.py | 291 +-- .../app/routes/model_connections_routes.py | 17 +- .../app/routes/new_llm_config_routes.py | 480 ----- .../app/routes/search_spaces_routes.py | 360 +--- .../app/routes/vision_llm_routes.py | 304 ---- surfsense_backend/app/schemas/__init__.py | 45 - .../app/schemas/image_generation.py | 165 +- .../app/schemas/new_llm_config.py | 256 --- surfsense_backend/app/schemas/vision_llm.py | 116 -- .../app/services/auto_model_pin_service.py | 20 +- .../app/services/billable_calls.py | 57 +- surfsense_backend/app/services/llm_service.py | 24 +- .../app/services/model_list_service.py | 2 +- .../openrouter_integration_service.py | 93 +- .../app/services/pricing_registration.py | 6 +- .../app/services/quality_score.py | 2 +- .../app/services/vision_llm_router_service.py | 160 -- .../app/services/vision_model_list_service.py | 134 -- .../scripts/verify_chat_image_capability.py | 43 +- .../builtin/agent_task/test_dependencies.py | 40 +- .../runtime/test_executor_action_ctx.py | 18 +- .../schemas/definition/test_envelope.py | 18 +- .../test_automation_service_policy.py | 120 +- .../automations/services/test_model_policy.py | 62 +- .../routes/test_byok_supports_image_input.py | 110 -- .../routes/test_global_configs_is_premium.py | 184 -- ...t_global_new_llm_configs_supports_image.py | 106 -- .../tests/unit/routes/test_image_gen_quota.py | 62 +- .../services/test_agent_billing_resolver.py | 232 +-- .../services/test_auto_model_pin_service.py | 135 +- .../test_image_gen_api_base_defense.py | 54 +- .../test_openrouter_integration_service.py | 93 +- .../services/test_pricing_registration.py | 74 - .../tests/unit/services/test_quality_score.py | 2 +- .../test_vision_llm_api_base_defense.py | 77 - surfsense_evals/README.md | 4 +- .../parser_compare/run_artifact.json | 2 +- .../src/surfsense_evals/core/cli.py | 73 +- .../core/clients/search_space.py | 116 +- .../src/surfsense_evals/core/config.py | 21 +- .../src/surfsense_evals/core/registry.py | 4 +- .../src/surfsense_evals/core/vision_llm.py | 4 +- .../suites/medical/medxpertqa/runner.py | 2 +- .../multimodal_doc/mmlongbench/runner.py | 2 +- .../multimodal_doc/parser_compare/runner.py | 2 +- .../suites/research/crag/runner.py | 2 +- .../suites/research/frames/runner.py | 2 +- surfsense_evals/tests/core/test_clients.py | 23 +- surfsense_evals/tests/core/test_config.py | 30 +- .../tests/test_integration_smoke.py | 2 +- .../image-models/page.tsx | 6 - .../search-space-settings/roles/page.tsx | 6 - .../vision-models/page.tsx | 6 - .../image-gen-config-mutation.atoms.ts | 96 - .../image-gen-config-query.atoms.ts | 33 - .../new-llm-config-mutation.atoms.ts | 132 -- .../new-llm-config-query.atoms.ts | 98 -- .../vision-llm-config-mutation.atoms.ts | 87 - .../vision-llm-config-query.atoms.ts | 51 - .../components/new-chat/chat-header.tsx | 153 +- .../settings/agent-model-manager.tsx | 423 ----- .../settings/image-model-manager.tsx | 489 ------ .../components/settings/llm-role-manager.tsx | 443 ----- .../settings/vision-model-manager.tsx | 486 ----- .../components/shared/image-config-dialog.tsx | 456 ----- .../components/shared/llm-config-form.tsx | 527 ------ .../components/shared/model-config-dialog.tsx | 339 ---- .../shared/vision-config-dialog.tsx | 478 ----- .../contracts/enums/image-gen-providers.ts | 105 -- surfsense_web/contracts/enums/llm-models.ts | 1558 ----------------- .../contracts/enums/llm-providers.ts | 197 --- .../contracts/enums/vision-providers.ts | 168 -- .../contracts/types/new-llm-config.types.ts | 476 ----- .../lib/apis/image-gen-config-api.service.ts | 81 - .../lib/apis/new-llm-config-api.service.ts | 178 -- .../lib/apis/vision-llm-config-api.service.ts | 63 - surfsense_web/lib/query-client/cache-keys.ts | 19 - surfsense_web/messages/en.json | 8 - surfsense_web/messages/es.json | 31 +- surfsense_web/messages/hi.json | 31 +- surfsense_web/messages/pt.json | 31 +- surfsense_web/messages/zh.json | 46 +- 93 files changed, 956 insertions(+), 11442 deletions(-) create mode 100644 surfsense_backend/alembic/versions/161_remove_legacy_model_configs.py delete mode 100644 surfsense_backend/app/routes/new_llm_config_routes.py delete mode 100644 surfsense_backend/app/routes/vision_llm_routes.py delete mode 100644 surfsense_backend/app/schemas/new_llm_config.py delete mode 100644 surfsense_backend/app/schemas/vision_llm.py delete mode 100644 surfsense_backend/app/services/vision_llm_router_service.py delete mode 100644 surfsense_backend/app/services/vision_model_list_service.py delete mode 100644 surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py delete mode 100644 surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py delete mode 100644 surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py delete mode 100644 surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py delete mode 100644 surfsense_web/app/dashboard/[search_space_id]/search-space-settings/image-models/page.tsx delete mode 100644 surfsense_web/app/dashboard/[search_space_id]/search-space-settings/roles/page.tsx delete mode 100644 surfsense_web/app/dashboard/[search_space_id]/search-space-settings/vision-models/page.tsx delete mode 100644 surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts delete mode 100644 surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts delete mode 100644 surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts delete mode 100644 surfsense_web/atoms/new-llm-config/new-llm-config-query.atoms.ts delete mode 100644 surfsense_web/atoms/vision-llm-config/vision-llm-config-mutation.atoms.ts delete mode 100644 surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts delete mode 100644 surfsense_web/components/settings/agent-model-manager.tsx delete mode 100644 surfsense_web/components/settings/image-model-manager.tsx delete mode 100644 surfsense_web/components/settings/llm-role-manager.tsx delete mode 100644 surfsense_web/components/settings/vision-model-manager.tsx delete mode 100644 surfsense_web/components/shared/image-config-dialog.tsx delete mode 100644 surfsense_web/components/shared/llm-config-form.tsx delete mode 100644 surfsense_web/components/shared/model-config-dialog.tsx delete mode 100644 surfsense_web/components/shared/vision-config-dialog.tsx delete mode 100644 surfsense_web/contracts/enums/image-gen-providers.ts delete mode 100644 surfsense_web/contracts/enums/llm-models.ts delete mode 100644 surfsense_web/contracts/enums/llm-providers.ts delete mode 100644 surfsense_web/contracts/enums/vision-providers.ts delete mode 100644 surfsense_web/contracts/types/new-llm-config.types.ts delete mode 100644 surfsense_web/lib/apis/image-gen-config-api.service.ts delete mode 100644 surfsense_web/lib/apis/new-llm-config-api.service.ts delete mode 100644 surfsense_web/lib/apis/vision-llm-config-api.service.ts diff --git a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py index fba621a0c..8c74b637b 100644 --- a/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py +++ b/surfsense_backend/alembic/versions/138_add_thread_auto_model_pinning_fields.py @@ -4,7 +4,7 @@ Revision ID: 138 Revises: 137 Create Date: 2026-04-30 -Add a single thread-level column to persist the Auto (Fastest) model pin: +Add a single thread-level column to persist the Auto model pin: - pinned_llm_config_id: concrete resolved global LLM config id used for this thread. NULL means "no pin; Auto will resolve on next turn". diff --git a/surfsense_backend/alembic/versions/161_remove_legacy_model_configs.py b/surfsense_backend/alembic/versions/161_remove_legacy_model_configs.py new file mode 100644 index 000000000..2108d763c --- /dev/null +++ b/surfsense_backend/alembic/versions/161_remove_legacy_model_configs.py @@ -0,0 +1,270 @@ +"""remove legacy model config tables + +Revision ID: 161 +Revises: 160 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +from sqlalchemy.types import TypeEngine + +from alembic import op + +revision: str = "161" +down_revision: str | None = "160" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +litellm_provider = postgresql.ENUM( + "OPENAI", + "ANTHROPIC", + "GOOGLE", + "AZURE_OPENAI", + "BEDROCK", + "VERTEX_AI", + "GROQ", + "COHERE", + "MISTRAL", + "DEEPSEEK", + "XAI", + "OPENROUTER", + "TOGETHER_AI", + "FIREWORKS_AI", + "REPLICATE", + "PERPLEXITY", + "OLLAMA", + "ALIBABA_QWEN", + "MOONSHOT", + "ZHIPU", + "ANYSCALE", + "DEEPINFRA", + "CEREBRAS", + "SAMBANOVA", + "AI21", + "CLOUDFLARE", + "DATABRICKS", + "COMETAPI", + "HUGGINGFACE", + "GITHUB_MODELS", + "MINIMAX", + "CUSTOM", + name="litellmprovider", + create_type=False, +) +image_gen_provider = postgresql.ENUM( + "OPENAI", + "AZURE_OPENAI", + "GOOGLE", + "VERTEX_AI", + "BEDROCK", + "RECRAFT", + "OPENROUTER", + "XINFERENCE", + "NSCALE", + name="imagegenprovider", + create_type=False, +) +vision_provider = postgresql.ENUM( + "OPENAI", + "ANTHROPIC", + "GOOGLE", + "AZURE_OPENAI", + "VERTEX_AI", + "BEDROCK", + "XAI", + "OPENROUTER", + "OLLAMA", + "GROQ", + "TOGETHER_AI", + "FIREWORKS_AI", + "DEEPSEEK", + "MISTRAL", + "CUSTOM", + name="visionprovider", + create_type=False, +) + + +def _table_exists(table_name: str) -> bool: + return table_name in sa.inspect(op.get_bind()).get_table_names() + + +def _column_exists(table_name: str, column_name: str) -> bool: + if not _table_exists(table_name): + return False + return column_name in { + column["name"] for column in sa.inspect(op.get_bind()).get_columns(table_name) + } + + +def _drop_column_if_exists(table_name: str, column_name: str) -> None: + if _column_exists(table_name, column_name): + op.drop_column(table_name, column_name) + + +def _rename_column_if_exists( + table_name: str, + old_column_name: str, + new_column_name: str, + *, + existing_type: TypeEngine, + existing_nullable: bool = True, +) -> None: + if _column_exists(table_name, old_column_name) and not _column_exists( + table_name, new_column_name + ): + op.alter_column( + table_name, + old_column_name, + new_column_name=new_column_name, + existing_type=existing_type, + existing_nullable=existing_nullable, + ) + + +def upgrade() -> None: + for table_name in ( + "new_llm_configs", + "vision_llm_configs", + "image_generation_configs", + ): + if _table_exists(table_name): + op.drop_table(table_name) + + _drop_column_if_exists("searchspaces", "agent_llm_id") + _drop_column_if_exists("searchspaces", "image_generation_config_id") + _drop_column_if_exists("searchspaces", "vision_llm_config_id") + + _rename_column_if_exists( + "image_generations", + "image_generation_config_id", + "image_gen_model_id", + existing_type=sa.Integer(), + ) + + op.execute("DROP TYPE IF EXISTS litellmprovider") + op.execute("DROP TYPE IF EXISTS imagegenprovider") + op.execute("DROP TYPE IF EXISTS visionprovider") + + +def downgrade() -> None: + bind = op.get_bind() + litellm_provider.create(bind, checkfirst=True) + image_gen_provider.create(bind, checkfirst=True) + vision_provider.create(bind, checkfirst=True) + + _rename_column_if_exists( + "image_generations", + "image_gen_model_id", + "image_generation_config_id", + existing_type=sa.Integer(), + ) + + if _table_exists("searchspaces"): + if not _column_exists("searchspaces", "agent_llm_id"): + op.add_column( + "searchspaces", + sa.Column("agent_llm_id", sa.Integer(), nullable=True), + ) + if not _column_exists("searchspaces", "image_generation_config_id"): + op.add_column( + "searchspaces", + sa.Column("image_generation_config_id", sa.Integer(), nullable=True), + ) + if not _column_exists("searchspaces", "vision_llm_config_id"): + op.add_column( + "searchspaces", + sa.Column("vision_llm_config_id", sa.Integer(), nullable=True), + ) + + if not _table_exists("image_generation_configs"): + op.create_table( + "image_generation_configs", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("name", sa.String(length=100), nullable=False), + sa.Column("description", sa.String(length=500), nullable=True), + sa.Column("provider", image_gen_provider, nullable=False), + sa.Column("custom_provider", sa.String(length=100), nullable=True), + sa.Column("model_name", sa.String(length=100), nullable=False), + sa.Column("api_key", sa.String(), nullable=False), + sa.Column("api_base", sa.String(length=500), nullable=True), + sa.Column("api_version", sa.String(length=50), nullable=True), + sa.Column("litellm_params", sa.JSON(), nullable=True), + sa.Column("search_space_id", sa.Integer(), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.ForeignKeyConstraint( + ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_image_generation_configs_name"), + "image_generation_configs", + ["name"], + unique=False, + ) + + if not _table_exists("vision_llm_configs"): + op.create_table( + "vision_llm_configs", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("name", sa.String(length=100), nullable=False), + sa.Column("description", sa.String(length=500), nullable=True), + sa.Column("provider", vision_provider, nullable=False), + sa.Column("custom_provider", sa.String(length=100), nullable=True), + sa.Column("model_name", sa.String(length=100), nullable=False), + sa.Column("api_key", sa.String(), nullable=False), + sa.Column("api_base", sa.String(length=500), nullable=True), + sa.Column("api_version", sa.String(length=50), nullable=True), + sa.Column("litellm_params", sa.JSON(), nullable=True), + sa.Column("search_space_id", sa.Integer(), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.ForeignKeyConstraint( + ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_vision_llm_configs_name"), + "vision_llm_configs", + ["name"], + unique=False, + ) + + if not _table_exists("new_llm_configs"): + op.create_table( + "new_llm_configs", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("name", sa.String(length=100), nullable=False), + sa.Column("description", sa.String(length=500), nullable=True), + sa.Column("provider", litellm_provider, nullable=False), + sa.Column("custom_provider", sa.String(length=100), nullable=True), + sa.Column("model_name", sa.String(length=100), nullable=False), + sa.Column("api_key", sa.String(), nullable=False), + sa.Column("api_base", sa.String(length=500), nullable=True), + sa.Column("litellm_params", sa.JSON(), nullable=True), + sa.Column("system_instructions", sa.Text(), nullable=False), + sa.Column("use_default_system_instructions", sa.Boolean(), nullable=False), + sa.Column("citations_enabled", sa.Boolean(), nullable=False), + sa.Column("search_space_id", sa.Integer(), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.ForeignKeyConstraint( + ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_new_llm_configs_name"), + "new_llm_configs", + ["name"], + unique=False, + ) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py index 505831faa..d847e021a 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py @@ -215,7 +215,7 @@ def create_generate_image_tool( prompt=prompt, model=getattr(response, "_hidden_params", {}).get("model"), n=n, - image_generation_config_id=config_id, + image_gen_model_id=config_id, response_data=response_dict, search_space_id=search_space_id, access_token=access_token, diff --git a/surfsense_backend/app/agents/chat/runtime/llm_config.py b/surfsense_backend/app/agents/chat/runtime/llm_config.py index efc188df8..e00d16ee8 100644 --- a/surfsense_backend/app/agents/chat/runtime/llm_config.py +++ b/surfsense_backend/app/agents/chat/runtime/llm_config.py @@ -24,8 +24,6 @@ from langchain_core.messages import AIMessage, BaseMessage from langchain_core.outputs import ChatGenerationChunk, ChatResult from langchain_litellm import ChatLiteLLM from litellm import get_model_info -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession from app.agents.chat.runtime.prompt_caching import ( apply_litellm_prompt_caching, @@ -34,7 +32,6 @@ from app.services.llm_router_service import ( AUTO_MODE_ID, ChatLiteLLMRouter, _sanitize_content, - is_auto_mode, ) @@ -130,7 +127,7 @@ class AgentConfig: """ Complete configuration for the SurfSense agent. - This combines LLM settings with prompt configuration from NewLLMConfig. + This combines resolved model settings with prompt configuration. Supports Auto mode metadata (ID 0). Runtime callers must resolve Auto to a concrete global or BYOK model before constructing ChatLiteLLM. """ @@ -180,7 +177,7 @@ class AgentConfig: use_default_system_instructions=True, citations_enabled=True, config_id=AUTO_MODE_ID, - config_name="Auto (Fastest)", + config_name="Auto", is_auto_mode=True, billing_tier="free", is_premium=False, @@ -191,57 +188,12 @@ class AgentConfig: supports_image_input=True, ) - @classmethod - def from_new_llm_config(cls, config) -> "AgentConfig": - """Build an AgentConfig from a NewLLMConfig database model.""" - # Lazy import: keeps provider_capabilities (and litellm) out of init order. - from app.services.provider_capabilities import derive_supports_image_input - - provider_value = ( - config.provider.value - if hasattr(config.provider, "value") - else str(config.provider) - ) - litellm_params = config.litellm_params or {} - base_model = ( - litellm_params.get("base_model") - if isinstance(litellm_params, dict) - else None - ) - - return cls( - provider=provider_value, - model_name=config.model_name, - api_key=config.api_key, - api_base=config.api_base, - custom_provider=config.custom_provider, - litellm_params=config.litellm_params, - system_instructions=config.system_instructions, - use_default_system_instructions=config.use_default_system_instructions, - citations_enabled=config.citations_enabled, - config_id=config.id, - config_name=config.name, - is_auto_mode=False, - billing_tier="free", - is_premium=False, - anonymous_enabled=False, - quota_reserve_tokens=None, - # BYOK rows have no curated flag; ask LiteLLM (default-allow on - # unknown). The streaming safety net still blocks explicit text-only. - supports_image_input=derive_supports_image_input( - provider=provider_value.lower(), - model_name=config.model_name, - base_model=base_model, - custom_provider=config.custom_provider, - ), - ) - @classmethod def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig": """Build an AgentConfig from a YAML configuration dictionary. - Supports the same prompt fields as NewLLMConfig (system_instructions, - use_default_system_instructions, citations_enabled). + Supports prompt fields such as system_instructions, + use_default_system_instructions, and citations_enabled. """ # Lazy import: keeps provider_capabilities (and litellm) out of init order. from app.services.provider_capabilities import derive_supports_image_input @@ -334,82 +286,6 @@ def load_global_llm_config_by_id(llm_config_id: int) -> dict | None: return load_llm_config_from_yaml(llm_config_id) -async def load_new_llm_config_from_db( - session: AsyncSession, - config_id: int, -) -> "AgentConfig | None": - """Load a NewLLMConfig from the database by ID.""" - from app.db import NewLLMConfig - - try: - result = await session.execute( - select(NewLLMConfig).filter(NewLLMConfig.id == config_id) - ) - config = result.scalars().first() - - if not config: - print(f"Error: NewLLMConfig with id {config_id} not found") - return None - - return AgentConfig.from_new_llm_config(config) - except Exception as e: - print(f"Error loading NewLLMConfig from database: {e}") - return None - - -async def load_agent_llm_config_for_search_space( - session: AsyncSession, - search_space_id: int, -) -> "AgentConfig | None": - """Load the chat model config for a search space via its agent_llm_id. - - Positive id -> DB; negative -> YAML; None -> first global config (-1). - """ - from app.db import SearchSpace - - try: - result = await session.execute( - select(SearchSpace).filter(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - - if not search_space: - print(f"Error: SearchSpace with id {search_space_id} not found") - return None - - config_id = ( - search_space.agent_llm_id if search_space.agent_llm_id is not None else -1 - ) - return await load_agent_config(session, config_id, search_space_id) - except Exception as e: - print(f"Error loading chat model config for search space {search_space_id}: {e}") - return None - - -async def load_agent_config( - session: AsyncSession, - config_id: int, - search_space_id: int | None = None, -) -> "AgentConfig | None": - """Main config loader: id 0 -> Auto mode; negative -> YAML; positive -> DB.""" - if is_auto_mode(config_id): - return AgentConfig.from_auto_mode() - - if config_id < 0: - # In-memory covers static YAML + dynamic OpenRouter configs. - from app.config import config as app_config - - for cfg in app_config.GLOBAL_LLM_CONFIGS: - if cfg.get("id") == config_id: - return AgentConfig.from_yaml_config(cfg) - yaml_config = load_llm_config_from_yaml(config_id) - if yaml_config: - return AgentConfig.from_yaml_config(yaml_config) - return None - else: - return await load_new_llm_config_from_db(session, config_id) - - def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None: """Create a ChatLiteLLM instance from a global LLM config dictionary.""" if llm_config.get("custom_provider"): diff --git a/surfsense_backend/app/automations/services/model_policy.py b/surfsense_backend/app/automations/services/model_policy.py index b160fc78d..e18264246 100644 --- a/surfsense_backend/app/automations/services/model_policy.py +++ b/surfsense_backend/app/automations/services/model_policy.py @@ -2,11 +2,11 @@ Automations run unattended, so every run must be **billable**: it may only use either a premium global model (``billing_tier == "premium"``) or a user-provided -BYOK model (a positive config id pointing at a per-user/per-space DB row). Free +BYOK model (a positive model id pointing at a per-user/per-space DB row). Free global models and Auto mode are blocked, because Auto can dispatch to a free deployment and free models aren't metered in premium credits. -Config id conventions (shared across chat / image / vision): +Model id conventions (shared across chat / image / vision): - ``id == 0`` → Auto mode (``AUTO_MODE_ID`` / ``IMAGE_GEN_AUTO_MODE_ID`` / ``VISION_AUTO_MODE_ID``). Blocked. - ``id < 0`` → global YAML/OpenRouter config. Allowed only if premium. @@ -82,7 +82,7 @@ def get_model_eligibility( The ID-based core shared by both the search-space path (creation/eligibility) and the captured-snapshot path (runtime backstop). Each violation is - ``{"kind", "config_id", "reason"}``. + ``{"kind", "model_id", "reason"}``. """ checks: list[tuple[ModelKind, int | None]] = [ ("chat", chat_model_id), @@ -91,10 +91,10 @@ def get_model_eligibility( ] violations: list[dict] = [] - for kind, config_id in checks: - allowed, reason = _classify(kind, config_id) + for kind, model_id in checks: + allowed, reason = _classify(kind, model_id) if not allowed: - violations.append({"kind": kind, "model_id": config_id, "reason": reason}) + violations.append({"kind": kind, "model_id": model_id, "reason": reason}) return {"allowed": not violations, "violations": violations} diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index d690c1d7d..8c9662aa8 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -119,7 +119,7 @@ def load_global_llm_configs(): else: seen_slugs[slug] = cfg.get("id", 0) - # Stamp Auto (Fastest) ranking metadata. YAML configs are always + # Stamp Auto ranking metadata. YAML configs are always # Tier A — operator-curated, locked first when premium-eligible. # The OpenRouter refresh tick later re-stamps health for any cfg # whose provider == "openrouter" via _enrich_health. @@ -210,42 +210,6 @@ def load_global_image_gen_configs(): return [] -def load_global_vision_llm_configs(): - data = _global_config_data() - if not data: - return [] - - try: - configs = copy.deepcopy(data.get("global_vision_llm_configs", []) or []) - for cfg in configs: - if isinstance(cfg, dict): - cfg.setdefault("billing_tier", "free") - return configs - except Exception as e: - print(f"Warning: Failed to load global vision LLM configs: {e}") - return [] - - -def load_vision_llm_router_settings(): - default_settings = { - "routing_strategy": "usage-based-routing", - "num_retries": 3, - "allowed_fails": 3, - "cooldown_time": 60, - } - - data = _global_config_data() - if not data: - return default_settings - - try: - settings = data.get("vision_llm_router_settings", {}) - return {**default_settings, **settings} - except Exception as e: - print(f"Warning: Failed to load vision LLM router settings: {e}") - return default_settings - - def load_image_gen_router_settings(): """ Load router settings for image generation Auto mode from YAML file. @@ -482,12 +446,6 @@ def initialize_image_gen_router(): print(f"Warning: Failed to initialize Image Generation Router: {e}") -def initialize_vision_llm_router(): - # Retired: vision Auto now uses shared capability-filtered model selection - # over GLOBAL/BYOK chat models with supports_image_input=true. - return - - class Config: # Check if ffmpeg is installed if not is_ffmpeg_installed(): @@ -869,12 +827,6 @@ class Config: # Router settings for Image Generation Auto mode IMAGE_GEN_ROUTER_SETTINGS = load_image_gen_router_settings() - # Global Vision LLM Configurations (optional) - GLOBAL_VISION_LLM_CONFIGS = load_global_vision_llm_configs() - - # Router settings for Vision LLM Auto mode - VISION_LLM_ROUTER_SETTINGS = load_vision_llm_router_settings() - # Virtual GLOBAL connection/model catalog. This is server-only metadata # derived from global_llm_config.yaml; GLOBAL keys are not stored in DB. from app.services.global_model_catalog import ( diff --git a/surfsense_backend/app/config/global_llm_config.example.yaml b/surfsense_backend/app/config/global_llm_config.example.yaml index 06676511f..c5b65fee0 100644 --- a/surfsense_backend/app/config/global_llm_config.example.yaml +++ b/surfsense_backend/app/config/global_llm_config.example.yaml @@ -433,87 +433,11 @@ global_image_generation_configs: # rpm: 30 # litellm_params: {} -# ============================================================================= -# Vision LLM Configuration -# ============================================================================= -# These configurations power the vision autocomplete feature (screenshot analysis). -# Only vision-capable models should be used here (e.g. GPT-4o, Gemini Pro, Claude 3). -# Supported providers: OpenAI, Anthropic, Google, Azure OpenAI, Vertex AI, Bedrock, -# xAI, OpenRouter, Ollama, Groq, Together AI, Fireworks AI, DeepSeek, Mistral, Custom -# -# Auto mode (ID 0) uses LiteLLM Router for load balancing across all vision configs. - -# Router Settings for Vision LLM Auto Mode -vision_llm_router_settings: - routing_strategy: "usage-based-routing" - num_retries: 3 - allowed_fails: 3 - cooldown_time: 60 - -global_vision_llm_configs: - # Example: OpenAI GPT-4o (recommended for vision) - - id: -1001 - name: "Global GPT-4o Vision" - description: "OpenAI's GPT-4o with strong vision capabilities" - litellm_provider: "openai" - model_name: "gpt-4o" - api_key: "sk-your-openai-api-key-here" - api_base: "https://api.openai.com/v1" - rpm: 500 - tpm: 100000 - litellm_params: - temperature: 0.3 - max_tokens: 1000 - - # Example: Google Gemini 2.0 Flash - - id: -1002 - name: "Global Gemini 2.0 Flash" - description: "Google's fast vision model with large context" - litellm_provider: "gemini" - model_name: "gemini-2.0-flash" - api_key: "your-google-ai-api-key-here" - api_base: "https://generativelanguage.googleapis.com/v1beta" - rpm: 1000 - tpm: 200000 - litellm_params: - temperature: 0.3 - max_tokens: 1000 - - # Example: Anthropic Claude 3.5 Sonnet - - id: -1003 - name: "Global Claude 3.5 Sonnet Vision" - description: "Anthropic's Claude 3.5 Sonnet with vision support" - litellm_provider: "anthropic" - model_name: "claude-3-5-sonnet-20241022" - api_key: "sk-ant-your-anthropic-api-key-here" - api_base: "https://api.anthropic.com/v1" - rpm: 1000 - tpm: 100000 - litellm_params: - temperature: 0.3 - max_tokens: 1000 - - # Example: Azure OpenAI GPT-4o - # - id: -1004 - # name: "Global Azure GPT-4o Vision" - # description: "Azure-hosted GPT-4o for vision analysis" - # litellm_provider: "azure" - # model_name: "azure/gpt-4o-deployment" - # api_key: "your-azure-api-key-here" - # api_base: "https://your-resource.openai.azure.com" - # api_version: "2024-02-15-preview" - # rpm: 500 - # tpm: 100000 - # litellm_params: - # temperature: 0.3 - # max_tokens: 1000 - # base_model: "gpt-4o" - # Notes: # - ID 0 is reserved for "Auto" mode - uses LiteLLM Router for load balancing # - Use negative IDs to distinguish global models from BYOK/local DB models -# - IDs must be unique across chat, vision, and image generation configs -# - Suggested static ranges: chat -1..-999, vision -1001..-1999, image -2001..-2999 +# - IDs must be unique across chat and image generation configs +# - Suggested static ranges: chat -1..-999, image -2001..-2999 # - The 'api_key' field will not be exposed to users via API # - system_instructions: Custom prompt or empty string to use defaults # - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 728031fa0..38d0ffe33 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -198,81 +198,6 @@ class DocumentStatus: return None -class LiteLLMProvider(StrEnum): - """ - Enum for LLM providers supported by LiteLLM. - """ - - OPENAI = "OPENAI" - ANTHROPIC = "ANTHROPIC" - GOOGLE = "GOOGLE" - AZURE_OPENAI = "AZURE_OPENAI" - BEDROCK = "BEDROCK" - VERTEX_AI = "VERTEX_AI" - GROQ = "GROQ" - COHERE = "COHERE" - MISTRAL = "MISTRAL" - DEEPSEEK = "DEEPSEEK" - XAI = "XAI" - OPENROUTER = "OPENROUTER" - TOGETHER_AI = "TOGETHER_AI" - FIREWORKS_AI = "FIREWORKS_AI" - REPLICATE = "REPLICATE" - PERPLEXITY = "PERPLEXITY" - OLLAMA = "OLLAMA" - ALIBABA_QWEN = "ALIBABA_QWEN" - MOONSHOT = "MOONSHOT" - ZHIPU = "ZHIPU" - ANYSCALE = "ANYSCALE" - DEEPINFRA = "DEEPINFRA" - CEREBRAS = "CEREBRAS" - SAMBANOVA = "SAMBANOVA" - AI21 = "AI21" - CLOUDFLARE = "CLOUDFLARE" - DATABRICKS = "DATABRICKS" - COMETAPI = "COMETAPI" - HUGGINGFACE = "HUGGINGFACE" - GITHUB_MODELS = "GITHUB_MODELS" - MINIMAX = "MINIMAX" - CUSTOM = "CUSTOM" - - -class ImageGenProvider(StrEnum): - """ - Enum for image generation providers supported by LiteLLM. - This is a subset of LLM providers — only those that support image generation. - See: https://docs.litellm.ai/docs/image_generation#supported-providers - """ - - OPENAI = "OPENAI" - AZURE_OPENAI = "AZURE_OPENAI" - GOOGLE = "GOOGLE" # Google AI Studio - VERTEX_AI = "VERTEX_AI" - BEDROCK = "BEDROCK" # AWS Bedrock - RECRAFT = "RECRAFT" - OPENROUTER = "OPENROUTER" - XINFERENCE = "XINFERENCE" - NSCALE = "NSCALE" - - -class VisionProvider(StrEnum): - OPENAI = "OPENAI" - ANTHROPIC = "ANTHROPIC" - GOOGLE = "GOOGLE" - AZURE_OPENAI = "AZURE_OPENAI" - VERTEX_AI = "VERTEX_AI" - BEDROCK = "BEDROCK" - XAI = "XAI" - OPENROUTER = "OPENROUTER" - OLLAMA = "OLLAMA" - GROQ = "GROQ" - TOGETHER_AI = "TOGETHER_AI" - FIREWORKS_AI = "FIREWORKS_AI" - DEEPSEEK = "DEEPSEEK" - MISTRAL = "MISTRAL" - CUSTOM = "CUSTOM" - - class ConnectionScope(StrEnum): GLOBAL = "GLOBAL" SEARCH_SPACE = "SEARCH_SPACE" @@ -710,11 +635,11 @@ class NewChatThread(BaseModel, TimestampMixin): default=False, server_default="false", ) - # Auto (Fastest) model pin for this thread: concrete resolved global LLM + # Auto model pin for this thread: concrete resolved global LLM # config id. NULL means no pin; Auto will resolve on the next turn. # Single-writer invariant: only app.services.auto_model_pin_service sets # or clears this column (plus bulk clears when a search space's - # agent_llm_id changes). Unindexed: all reads are by primary key. + # chat_model_id changes). Unindexed: all reads are by primary key. pinned_llm_config_id = Column(Integer, nullable=True) # Surface metadata for first-party SurfSense and external chat threads. @@ -1686,75 +1611,6 @@ class Model(BaseModel, TimestampMixin): ) -class ImageGenerationConfig(BaseModel, TimestampMixin): - """ - Dedicated configuration table for image generation models. - - Separate from NewLLMConfig because image generation models don't need - system_instructions, citations_enabled, or use_default_system_instructions. - They only need provider credentials and model parameters. - """ - - __tablename__ = "image_generation_configs" - - name = Column(String(100), nullable=False, index=True) - description = Column(String(500), nullable=True) - - # Provider & model (uses ImageGenProvider, NOT LiteLLMProvider) - provider = Column(SQLAlchemyEnum(ImageGenProvider), nullable=False) - custom_provider = Column(String(100), nullable=True) - model_name = Column(String(100), nullable=False) - - # Credentials - api_key = Column(String, nullable=False) - api_base = Column(String(500), nullable=True) - api_version = Column(String(50), nullable=True) # Azure-specific - - # Additional litellm parameters - litellm_params = Column(JSON, nullable=True, default={}) - - # Relationships - search_space_id = Column( - Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False - ) - search_space = relationship( - "SearchSpace", back_populates="image_generation_configs" - ) - - # User who created this config - user_id = Column( - UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False - ) - user = relationship("User", back_populates="image_generation_configs") - - -class VisionLLMConfig(BaseModel, TimestampMixin): - __tablename__ = "vision_llm_configs" - - name = Column(String(100), nullable=False, index=True) - description = Column(String(500), nullable=True) - - provider = Column(SQLAlchemyEnum(VisionProvider), nullable=False) - custom_provider = Column(String(100), nullable=True) - model_name = Column(String(100), nullable=False) - - api_key = Column(String, nullable=False) - api_base = Column(String(500), nullable=True) - api_version = Column(String(50), nullable=True) - - litellm_params = Column(JSON, nullable=True, default={}) - - search_space_id = Column( - Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False - ) - search_space = relationship("SearchSpace", back_populates="vision_llm_configs") - - user_id = Column( - UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False - ) - user = relationship("User", back_populates="vision_llm_configs") - - class ImageGeneration(BaseModel, TimestampMixin): """ Stores image generation requests and results using litellm.aimage_generation(). @@ -1786,10 +1642,9 @@ class ImageGeneration(BaseModel, TimestampMixin): style = Column(String(50), nullable=True) # Model-specific style parameter response_format = Column(String(50), nullable=True) # "url" or "b64_json" - # Image generation config reference - # 0 = Auto mode (router), negative IDs = global configs from YAML, - # positive IDs = ImageGenerationConfig records in DB - image_generation_config_id = Column(Integer, nullable=True) + # Image generation model provenance. + # 0 = Auto mode, negative IDs = GLOBAL models, positive IDs = Model records. + image_gen_model_id = Column(Integer, nullable=True) # Response data (full litellm response as JSONB) — present on success response_data = Column(JSONB, nullable=True) @@ -1831,23 +1686,7 @@ class SearchSpace(BaseModel, TimestampMixin): shared_memory_md = Column(Text, nullable=True, server_default="") - # Search space-level LLM preferences (shared by all members) - # Note: ID values: - # - 0: Auto mode (uses LiteLLM Router for load balancing) - default for new search spaces - # - Negative IDs: Global configs from YAML - # - Positive IDs: Custom configs from DB (NewLLMConfig table) - agent_llm_id = Column( - Integer, nullable=True, default=0 - ) # For chat operations, defaults to Auto mode - image_generation_config_id = Column( - Integer, nullable=True, default=0 - ) # For image generation, defaults to Auto mode - vision_llm_config_id = Column( - Integer, nullable=True, default=0 - ) # For vision/screenshot analysis, defaults to Auto mode - - # New connection/model role bindings. These supersede the legacy config - # columns above without removing them in this PR. + # Connection/model role bindings. # Note: ID values preserve the existing convention: # - 0: Auto mode # - Negative IDs: Global virtual models from global_llm_config.yaml @@ -1931,24 +1770,6 @@ class SearchSpace(BaseModel, TimestampMixin): order_by="SearchSourceConnector.id", cascade="all, delete-orphan", ) - new_llm_configs = relationship( - "NewLLMConfig", - back_populates="search_space", - order_by="NewLLMConfig.id", - cascade="all, delete-orphan", - ) - image_generation_configs = relationship( - "ImageGenerationConfig", - back_populates="search_space", - order_by="ImageGenerationConfig.id", - cascade="all, delete-orphan", - ) - vision_llm_configs = relationship( - "VisionLLMConfig", - back_populates="search_space", - order_by="VisionLLMConfig.id", - cascade="all, delete-orphan", - ) connections = relationship( "Connection", back_populates="search_space", @@ -2057,64 +1878,6 @@ class SearchSourceConnector(BaseModel, TimestampMixin): documents = relationship("Document", back_populates="connector") -class NewLLMConfig(BaseModel, TimestampMixin): - """ - New LLM configuration table that combines model settings with prompt configuration. - - This table provides: - - LLM model configuration (provider, model_name, api_key, etc.) - - Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS) - - Citation toggle (enable/disable citation instructions) - - Note: Tools instructions are built by get_tools_instructions(thread_visibility) (personal vs shared memory). - """ - - __tablename__ = "new_llm_configs" - - name = Column(String(100), nullable=False, index=True) - description = Column(String(500), nullable=True) - - # === LLM Model Configuration (from original LLMConfig, excluding 'language') === - # Provider from the enum - provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False) - # Custom provider name when provider is CUSTOM - custom_provider = Column(String(100), nullable=True) - # Just the model name without provider prefix - model_name = Column(String(100), nullable=False) - # API Key should be encrypted before storing - api_key = Column(String, nullable=False) - api_base = Column(String(500), nullable=True) - # For any other parameters that litellm supports - litellm_params = Column(JSON, nullable=True, default={}) - - # === Prompt Configuration === - # Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS) - # Users can customize this from the UI - system_instructions = Column( - Text, - nullable=False, - default="", # Empty string means use default SURFSENSE_SYSTEM_INSTRUCTIONS - ) - # Whether to use the default system instructions when system_instructions is empty - use_default_system_instructions = Column(Boolean, nullable=False, default=True) - - # Citation toggle - when enabled, SURFSENSE_CITATION_INSTRUCTIONS is injected - # When disabled, an anti-citation prompt is injected instead - citations_enabled = Column(Boolean, nullable=False, default=True) - - # === Relationships === - search_space_id = Column( - Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False - ) - search_space = relationship("SearchSpace", back_populates="new_llm_configs") - - # User who created this config - user_id = Column( - UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False - ) - user = relationship("User", back_populates="new_llm_configs") - - class Log(BaseModel, TimestampMixin): __tablename__ = "logs" @@ -2481,25 +2244,6 @@ if config.AUTH_TYPE == "GOOGLE": passive_deletes=True, ) - # LLM configs created by this user - new_llm_configs = relationship( - "NewLLMConfig", - back_populates="user", - passive_deletes=True, - ) - - # Image generation configs created by this user - image_generation_configs = relationship( - "ImageGenerationConfig", - back_populates="user", - passive_deletes=True, - ) - - vision_llm_configs = relationship( - "VisionLLMConfig", - back_populates="user", - passive_deletes=True, - ) connections = relationship( "Connection", back_populates="user", @@ -2632,25 +2376,6 @@ else: passive_deletes=True, ) - # LLM configs created by this user - new_llm_configs = relationship( - "NewLLMConfig", - back_populates="user", - passive_deletes=True, - ) - - # Image generation configs created by this user - image_generation_configs = relationship( - "ImageGenerationConfig", - back_populates="user", - passive_deletes=True, - ) - - vision_llm_configs = relationship( - "VisionLLMConfig", - back_populates="user", - passive_deletes=True, - ) connections = relationship( "Connection", back_populates="user", diff --git a/surfsense_backend/app/prompts/default_system_instructions.py b/surfsense_backend/app/prompts/default_system_instructions.py index fd0a8e186..b968fc1f0 100644 --- a/surfsense_backend/app/prompts/default_system_instructions.py +++ b/surfsense_backend/app/prompts/default_system_instructions.py @@ -82,7 +82,7 @@ def build_configurable_system_prompt( *, model_name: str | None = None, ) -> str: - """Build a configurable SurfSense system prompt (NewLLMConfig path). + """Build a configurable SurfSense system prompt. See :func:`app.prompts.system_prompt_composer.composer.compose_system_prompt` for full parameter docs. @@ -104,7 +104,7 @@ def build_configurable_system_prompt( def get_default_system_instructions() -> str: """Return the default ```` block (no tools / citations). - Useful for populating the UI when seeding ``NewLLMConfig.system_instructions``. + Useful for populating the UI when editing custom system instructions. The output reflects the current fragment tree, not a baked-in constant. """ resolved_today = datetime.now(UTC).date().isoformat() diff --git a/surfsense_backend/app/prompts/system_prompt_composer/composer.py b/surfsense_backend/app/prompts/system_prompt_composer/composer.py index 3849af313..c639d4aa0 100644 --- a/surfsense_backend/app/prompts/system_prompt_composer/composer.py +++ b/surfsense_backend/app/prompts/system_prompt_composer/composer.py @@ -348,8 +348,7 @@ def compose_system_prompt( mcp_connector_tools: ``{server_name: [tool_names...]}`` to inject an explicit MCP routing block. custom_system_instructions: Free-form instructions that override - the default ```` block (legacy support - for ``NewLLMConfig.system_instructions``). + the default ```` block. use_default_system_instructions: When ``custom_system_instructions`` is empty/None, fall back to defaults (legacy semantics). citations_enabled: Include ``citations_on.md`` (true) or diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index f9f6b3d28..2b997cef5 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -47,7 +47,6 @@ from .model_connections_routes import router as model_connections_router from .memory_routes import router as memory_router from .model_list_routes import router as model_list_router from .new_chat_routes import router as new_chat_router -from .new_llm_config_routes import router as new_llm_config_router from .notes_routes import router as notes_router from .notion_add_connector_route import router as notion_add_connector_router from .obsidian_plugin_routes import router as obsidian_plugin_router @@ -64,7 +63,6 @@ from .stripe_routes import router as stripe_router from .team_memory_routes import router as team_memory_router from .teams_add_connector_route import router as teams_add_connector_router from .video_presentations_routes import router as video_presentations_router -from .vision_llm_routes import router as vision_llm_router from .youtube_routes import router as youtube_router router = APIRouter() @@ -99,7 +97,6 @@ router.include_router( ) # Video presentation status and streaming router.include_router(reports_router) # Report CRUD and multi-format export router.include_router(image_generation_router) # Image generation via litellm -router.include_router(vision_llm_router) # Vision LLM configs for screenshot analysis router.include_router(search_source_connectors_router) router.include_router(google_calendar_add_connector_router) router.include_router(google_gmail_add_connector_router) @@ -117,7 +114,6 @@ router.include_router(jira_add_connector_router) router.include_router(confluence_add_connector_router) router.include_router(clickup_add_connector_router) router.include_router(dropbox_add_connector_router) -router.include_router(new_llm_config_router) # LLM configs with prompt configuration router.include_router(model_connections_router) # Connection-centric model catalog router.include_router(model_list_router) # Dynamic model catalogue from OpenRouter router.include_router(logs_router) diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 29a2b58bc..7e95d4dba 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -1,7 +1,5 @@ """ Image Generation routes: -- CRUD for ImageGenerationConfig (user-created image model configs) -- Global image gen configs endpoint (from YAML) - Image generation execution (calls litellm.aimage_generation()) - CRUD for ImageGeneration records (results) - Image serving endpoint (serves b64_json images from DB, protected by signed tokens) @@ -21,7 +19,6 @@ from sqlalchemy.orm import selectinload from app.config import config from app.db import ( ImageGeneration, - ImageGenerationConfig, Model, Permission, SearchSpace, @@ -30,14 +27,14 @@ from app.db import ( get_async_session, ) from app.schemas import ( - GlobalImageGenConfigRead, - ImageGenerationConfigCreate, - ImageGenerationConfigRead, - ImageGenerationConfigUpdate, ImageGenerationCreate, ImageGenerationListRead, ImageGenerationRead, ) +from app.services.auto_model_pin_service import ( + auto_model_candidates, + choose_auto_model_candidate, +) from app.services.billable_calls import ( DEFAULT_IMAGE_RESERVE_MICROS, QuotaInsufficientError, @@ -47,12 +44,8 @@ from app.services.image_gen_router_service import ( IMAGE_GEN_AUTO_MODE_ID, is_image_gen_auto_mode, ) -from app.services.auto_model_pin_service import ( - auto_model_candidates, - choose_auto_model_candidate, -) -from app.services.model_resolver import to_litellm from app.services.model_capabilities import has_capability +from app.services.model_resolver import to_litellm from app.users import current_active_user from app.utils.rbac import check_permission from app.utils.signed_image_urls import verify_image_token @@ -131,14 +124,14 @@ async def _execute_image_generation( Call litellm.aimage_generation() with the appropriate config. Resolution order: - 1. Explicit image_generation_config_id on the request - 2. Search space's image_generation_config_id preference + 1. Explicit image_gen_model_id on the request + 2. Search space's image_gen_model_id preference 3. Falls back to Auto mode if available """ - config_id = image_gen.image_generation_config_id + config_id = image_gen.image_gen_model_id if config_id is None: config_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID - image_gen.image_generation_config_id = config_id + image_gen.image_gen_model_id = config_id # Build kwargs gen_kwargs = {} @@ -163,7 +156,7 @@ async def _execute_image_generation( if not candidates: raise ValueError("No image-generation models are available for Auto mode") config_id = int(choose_auto_model_candidate(candidates, search_space.id)["id"]) - image_gen.image_generation_config_id = config_id + image_gen.image_gen_model_id = config_id if config_id < 0: global_model = _get_global_model(config_id) @@ -228,266 +221,6 @@ async def _execute_image_generation( image_gen.model = hidden["model"] -# ============================================================================= -# Global Image Generation Configs (from YAML) -# ============================================================================= - - -@router.get( - "/global-image-generation-configs", - response_model=list[GlobalImageGenConfigRead], -) -async def get_global_image_gen_configs( - user: User = Depends(current_active_user), -): - """Get all global image generation configs. API keys are hidden.""" - try: - global_configs = config.GLOBAL_IMAGE_GEN_CONFIGS - safe_configs = [] - - if global_configs and len(global_configs) > 0: - safe_configs.append( - { - "id": 0, - "name": "Auto (Fastest)", - "description": "Automatically routes across available image generation providers.", - "provider": "AUTO", - "custom_provider": None, - "model_name": "auto", - "api_base": None, - "api_version": None, - "litellm_params": {}, - "is_global": True, - "is_auto_mode": True, - # Auto mode currently treated as free until per-deployment - # billing-tier surfacing lands (see _resolve_billing_for_image_gen). - "billing_tier": "free", - "is_premium": False, - } - ) - - for cfg in global_configs: - billing_tier = str(cfg.get("billing_tier", "free")).lower() - safe_configs.append( - { - "id": cfg.get("id"), - "name": cfg.get("name"), - "description": cfg.get("description"), - "provider": cfg.get("provider") or cfg.get("litellm_provider"), - "custom_provider": cfg.get("custom_provider"), - "model_name": cfg.get("model_name"), - "api_base": cfg.get("api_base") or None, - "api_version": cfg.get("api_version") or None, - "litellm_params": cfg.get("litellm_params", {}), - "is_global": True, - "billing_tier": billing_tier, - # Mirror chat (``new_llm_config_routes``) so the new-chat - # selector's premium badge logic keys off the same - # field across chat / image / vision tabs. - "is_premium": billing_tier == "premium", - "quota_reserve_micros": cfg.get("quota_reserve_micros"), - } - ) - - return safe_configs - except Exception as e: - logger.exception("Failed to fetch global image generation configs") - raise HTTPException( - status_code=500, detail=f"Failed to fetch configs: {e!s}" - ) from e - - -# ============================================================================= -# ImageGenerationConfig CRUD -# ============================================================================= - - -@router.post("/image-generation-configs", response_model=ImageGenerationConfigRead) -async def create_image_gen_config( - config_data: ImageGenerationConfigCreate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """Create a new image generation config for a search space.""" - try: - await check_permission( - session, - user, - config_data.search_space_id, - Permission.IMAGE_GENERATIONS_CREATE.value, - "You don't have permission to create image generation configs in this search space", - ) - - db_config = ImageGenerationConfig(**config_data.model_dump(), user_id=user.id) - session.add(db_config) - await session.commit() - await session.refresh(db_config) - return db_config - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to create ImageGenerationConfig") - raise HTTPException( - status_code=500, detail=f"Failed to create config: {e!s}" - ) from e - - -@router.get("/image-generation-configs", response_model=list[ImageGenerationConfigRead]) -async def list_image_gen_configs( - search_space_id: int, - skip: int = 0, - limit: int = 100, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """List image generation configs for a search space.""" - try: - await check_permission( - session, - user, - search_space_id, - Permission.IMAGE_GENERATIONS_READ.value, - "You don't have permission to view image generation configs in this search space", - ) - - result = await session.execute( - select(ImageGenerationConfig) - .filter(ImageGenerationConfig.search_space_id == search_space_id) - .order_by(ImageGenerationConfig.created_at.desc()) - .offset(skip) - .limit(limit) - ) - return result.scalars().all() - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to list ImageGenerationConfigs") - raise HTTPException( - status_code=500, detail=f"Failed to fetch configs: {e!s}" - ) from e - - -@router.get( - "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead -) -async def get_image_gen_config( - config_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """Get a specific image generation config by ID.""" - try: - result = await session.execute( - select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) - ) - db_config = result.scalars().first() - if not db_config: - raise HTTPException(status_code=404, detail="Config not found") - - await check_permission( - session, - user, - db_config.search_space_id, - Permission.IMAGE_GENERATIONS_READ.value, - "You don't have permission to view image generation configs in this search space", - ) - return db_config - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to get ImageGenerationConfig") - raise HTTPException( - status_code=500, detail=f"Failed to fetch config: {e!s}" - ) from e - - -@router.put( - "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead -) -async def update_image_gen_config( - config_id: int, - update_data: ImageGenerationConfigUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """Update an existing image generation config.""" - try: - result = await session.execute( - select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) - ) - db_config = result.scalars().first() - if not db_config: - raise HTTPException(status_code=404, detail="Config not found") - - await check_permission( - session, - user, - db_config.search_space_id, - Permission.IMAGE_GENERATIONS_CREATE.value, - "You don't have permission to update image generation configs in this search space", - ) - - for key, value in update_data.model_dump(exclude_unset=True).items(): - setattr(db_config, key, value) - - await session.commit() - await session.refresh(db_config) - return db_config - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to update ImageGenerationConfig") - raise HTTPException( - status_code=500, detail=f"Failed to update config: {e!s}" - ) from e - - -@router.delete("/image-generation-configs/{config_id}", response_model=dict) -async def delete_image_gen_config( - config_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """Delete an image generation config.""" - try: - result = await session.execute( - select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) - ) - db_config = result.scalars().first() - if not db_config: - raise HTTPException(status_code=404, detail="Config not found") - - await check_permission( - session, - user, - db_config.search_space_id, - Permission.IMAGE_GENERATIONS_DELETE.value, - "You don't have permission to delete image generation configs in this search space", - ) - - await session.delete(db_config) - await session.commit() - return { - "message": "Image generation config deleted successfully", - "id": config_id, - } - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to delete ImageGenerationConfig") - raise HTTPException( - status_code=500, detail=f"Failed to delete config: {e!s}" - ) from e - - # ============================================================================= # Image Generation Execution + Results CRUD # ============================================================================= @@ -536,7 +269,7 @@ async def create_image_generation( raise HTTPException(status_code=404, detail="Search space not found") billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen( - session, data.image_generation_config_id, search_space + session, data.image_gen_model_id, search_space ) # billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError @@ -562,7 +295,7 @@ async def create_image_generation( size=data.size, style=data.style, response_format=data.response_format, - image_generation_config_id=data.image_generation_config_id, + image_gen_model_id=data.image_gen_model_id, search_space_id=data.search_space_id, created_by_id=user.id, ) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 1fd2e1e8e..90d246c54 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -11,6 +11,7 @@ from app.db import ( ConnectionScope, Model, ModelSource, + NewChatThread, Permission, SearchSpace, User, @@ -708,12 +709,26 @@ async def update_model_roles( search_space = await _get_search_space(session, search_space_id) updates = data.model_dump(exclude_unset=True) if "chat_model_id" in updates: - search_space.chat_model_id = await _validate_role_model_id( + previous_chat_model_id = search_space.chat_model_id + next_chat_model_id = await _validate_role_model_id( session, search_space_id=search_space_id, model_id=updates["chat_model_id"], capability="chat", ) + search_space.chat_model_id = next_chat_model_id + if next_chat_model_id != previous_chat_model_id: + await session.execute( + update(NewChatThread) + .where(NewChatThread.search_space_id == search_space_id) + .values(pinned_llm_config_id=None) + ) + logger.info( + "Cleared auto model pins for search_space_id=%s after chat_model_id change (%s -> %s)", + search_space_id, + previous_chat_model_id, + next_chat_model_id, + ) if "vision_model_id" in updates: search_space.vision_model_id = await _validate_role_model_id( session, diff --git a/surfsense_backend/app/routes/new_llm_config_routes.py b/surfsense_backend/app/routes/new_llm_config_routes.py deleted file mode 100644 index adba5b5ae..000000000 --- a/surfsense_backend/app/routes/new_llm_config_routes.py +++ /dev/null @@ -1,480 +0,0 @@ -""" -API routes for NewLLMConfig CRUD operations. - -NewLLMConfig combines model settings with prompt configuration: -- LLM provider, model, API key, etc. -- Configurable system instructions -- Citation toggle -""" - -import logging - -from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select - -from app.config import config -from app.db import ( - NewLLMConfig, - Permission, - User, - get_async_session, -) -from app.prompts.default_system_instructions import get_default_system_instructions -from app.schemas import ( - DefaultSystemInstructionsResponse, - GlobalNewLLMConfigRead, - NewLLMConfigCreate, - NewLLMConfigRead, - NewLLMConfigUpdate, -) -from app.services.llm_service import validate_llm_config -from app.services.provider_capabilities import derive_supports_image_input -from app.users import current_active_user -from app.utils.rbac import check_permission - -router = APIRouter() -logger = logging.getLogger(__name__) - - -def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead: - """Augment a BYOK chat config row with the derived ``supports_image_input``. - - There is no DB column for ``supports_image_input`` — the value is - resolved at the API boundary from LiteLLM's authoritative model map - (default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps - the response shape consistent across list / detail / create / update - endpoints without having to remember to set the field at every call - site. - """ - provider_value = ( - config.provider.value - if hasattr(config.provider, "value") - else str(config.provider) - ) - litellm_params = config.litellm_params or {} - base_model = ( - litellm_params.get("base_model") if isinstance(litellm_params, dict) else None - ) - supports_image_input = derive_supports_image_input( - provider=provider_value.lower(), - model_name=config.model_name, - base_model=base_model, - custom_provider=config.custom_provider, - ) - # ``model_validate`` runs the Pydantic conversion using the ORM - # attribute access path enabled by ``ConfigDict(from_attributes=True)``, - # then we layer the derived field on. ``model_copy(update=...)`` keeps - # the surface immutable from the caller's perspective. - base_read = NewLLMConfigRead.model_validate(config) - return base_read.model_copy(update={"supports_image_input": supports_image_input}) - - -# ============================================================================= -# Global Configs Routes -# ============================================================================= - - -@router.get("/global-new-llm-configs", response_model=list[GlobalNewLLMConfigRead]) -async def get_global_new_llm_configs( - user: User = Depends(current_active_user), -): - """ - Get all available global NewLLMConfig configurations. - These are pre-configured by the system administrator and available to all users. - API keys are not exposed through this endpoint. - - Includes: - - Auto mode (ID 0): Uses LiteLLM Router for automatic load balancing - - Global configs (negative IDs): Individual pre-configured LLM providers - """ - try: - global_configs = config.GLOBAL_LLM_CONFIGS - safe_configs = [] - - # Only include Auto mode if there are actual global configs to route to - # Auto mode requires at least one global config with valid API key - if global_configs and len(global_configs) > 0: - safe_configs.append( - { - "id": 0, - "name": "Auto (Fastest)", - "description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling. Recommended for most users.", - "provider": "AUTO", - "custom_provider": None, - "model_name": "auto", - "api_base": None, - "litellm_params": {}, - "system_instructions": "", - "use_default_system_instructions": True, - "citations_enabled": True, - "is_global": True, - "is_auto_mode": True, - "billing_tier": "free", - "is_premium": False, - "anonymous_enabled": False, - "seo_enabled": False, - "seo_slug": None, - "seo_title": None, - "seo_description": None, - "quota_reserve_tokens": None, - # Auto routes across the configured pool, which usually - # includes at least one vision-capable deployment, so - # treat Auto as image-capable. The router itself will - # still pick a vision-capable deployment for messages - # carrying image_url blocks (LiteLLM Router falls back - # on ``404`` per its ``allowed_fails`` policy). - "supports_image_input": True, - } - ) - - # Add individual global configs - for cfg in global_configs: - # Capability resolution: explicit value (YAML override or OR - # `_supports_image_input(model)` payload baked in by the - # OpenRouter integration service) wins. Fall back to the - # LiteLLM-driven helper which default-allows on unknown so - # we don't hide vision-capable models that happen to lack a - # YAML annotation. The streaming task safety net is the - # only place a False ever blocks. - if "supports_image_input" in cfg: - supports_image_input = bool(cfg.get("supports_image_input")) - else: - cfg_litellm_params = cfg.get("litellm_params") or {} - cfg_base_model = ( - cfg_litellm_params.get("base_model") - if isinstance(cfg_litellm_params, dict) - else None - ) - supports_image_input = derive_supports_image_input( - provider=cfg.get("provider") or cfg.get("litellm_provider"), - model_name=cfg.get("model_name"), - base_model=cfg_base_model, - custom_provider=cfg.get("custom_provider"), - ) - - safe_config = { - "id": cfg.get("id"), - "name": cfg.get("name"), - "description": cfg.get("description"), - "provider": cfg.get("provider") or cfg.get("litellm_provider"), - "custom_provider": cfg.get("custom_provider"), - "model_name": cfg.get("model_name"), - "api_base": cfg.get("api_base") or None, - "litellm_params": cfg.get("litellm_params", {}), - # New prompt configuration fields - "system_instructions": cfg.get("system_instructions", ""), - "use_default_system_instructions": cfg.get( - "use_default_system_instructions", True - ), - "citations_enabled": cfg.get("citations_enabled", True), - "is_global": True, - "billing_tier": cfg.get("billing_tier", "free"), - "is_premium": cfg.get("billing_tier", "free") == "premium", - "anonymous_enabled": cfg.get("anonymous_enabled", False), - "seo_enabled": cfg.get("seo_enabled", False), - "seo_slug": cfg.get("seo_slug"), - "seo_title": cfg.get("seo_title"), - "seo_description": cfg.get("seo_description"), - "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), - "supports_image_input": supports_image_input, - } - safe_configs.append(safe_config) - - return safe_configs - except Exception as e: - logger.exception("Failed to fetch global NewLLMConfigs") - raise HTTPException( - status_code=500, detail=f"Failed to fetch global configurations: {e!s}" - ) from e - - -# ============================================================================= -# CRUD Routes -# ============================================================================= - - -@router.post("/new-llm-configs", response_model=NewLLMConfigRead) -async def create_new_llm_config( - config_data: NewLLMConfigCreate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Create a new NewLLMConfig for a search space. - Requires LLM_CONFIGS_CREATE permission. - """ - try: - # Verify user has permission - await check_permission( - session, - user, - config_data.search_space_id, - Permission.LLM_CONFIGS_CREATE.value, - "You don't have permission to create LLM configurations in this search space", - ) - - # Validate the LLM configuration by making a test API call - is_valid, error_message = await validate_llm_config( - provider=config_data.provider.value, - model_name=config_data.model_name, - api_key=config_data.api_key, - api_base=config_data.api_base, - custom_provider=config_data.custom_provider, - litellm_params=config_data.litellm_params, - ) - - if not is_valid: - raise HTTPException( - status_code=400, - detail=f"Invalid LLM configuration: {error_message}", - ) - - # Create the config with user association - db_config = NewLLMConfig(**config_data.model_dump(), user_id=user.id) - session.add(db_config) - await session.commit() - await session.refresh(db_config) - - return _serialize_byok_config(db_config) - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to create NewLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to create configuration: {e!s}" - ) from e - - -@router.get("/new-llm-configs", response_model=list[NewLLMConfigRead]) -async def list_new_llm_configs( - search_space_id: int, - skip: int = 0, - limit: int = 100, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Get all NewLLMConfigs for a search space. - Requires LLM_CONFIGS_READ permission. - """ - try: - # Verify user has permission - await check_permission( - session, - user, - search_space_id, - Permission.LLM_CONFIGS_READ.value, - "You don't have permission to view LLM configurations in this search space", - ) - - result = await session.execute( - select(NewLLMConfig) - .filter(NewLLMConfig.search_space_id == search_space_id) - .order_by(NewLLMConfig.created_at.desc()) - .offset(skip) - .limit(limit) - ) - - return [_serialize_byok_config(cfg) for cfg in result.scalars().all()] - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to list NewLLMConfigs") - raise HTTPException( - status_code=500, detail=f"Failed to fetch configurations: {e!s}" - ) from e - - -@router.get( - "/new-llm-configs/default-system-instructions", - response_model=DefaultSystemInstructionsResponse, -) -async def get_default_system_instructions_endpoint( - user: User = Depends(current_active_user), -): - """ - Get the default SURFSENSE_SYSTEM_INSTRUCTIONS template. - Useful for pre-populating the UI when creating a new configuration. - """ - return DefaultSystemInstructionsResponse( - default_system_instructions=get_default_system_instructions() - ) - - -@router.get("/new-llm-configs/{config_id}", response_model=NewLLMConfigRead) -async def get_new_llm_config( - config_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Get a specific NewLLMConfig by ID. - Requires LLM_CONFIGS_READ permission. - """ - try: - result = await session.execute( - select(NewLLMConfig).filter(NewLLMConfig.id == config_id) - ) - config = result.scalars().first() - - if not config: - raise HTTPException(status_code=404, detail="Configuration not found") - - # Verify user has permission - await check_permission( - session, - user, - config.search_space_id, - Permission.LLM_CONFIGS_READ.value, - "You don't have permission to view LLM configurations in this search space", - ) - - return _serialize_byok_config(config) - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to get NewLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to fetch configuration: {e!s}" - ) from e - - -@router.put("/new-llm-configs/{config_id}", response_model=NewLLMConfigRead) -async def update_new_llm_config( - config_id: int, - update_data: NewLLMConfigUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Update an existing NewLLMConfig. - Requires LLM_CONFIGS_UPDATE permission. - """ - try: - result = await session.execute( - select(NewLLMConfig).filter(NewLLMConfig.id == config_id) - ) - config = result.scalars().first() - - if not config: - raise HTTPException(status_code=404, detail="Configuration not found") - - # Verify user has permission - await check_permission( - session, - user, - config.search_space_id, - Permission.LLM_CONFIGS_UPDATE.value, - "You don't have permission to update LLM configurations in this search space", - ) - - update_dict = update_data.model_dump(exclude_unset=True) - - # If updating LLM settings, validate them - if any( - key in update_dict - for key in [ - "provider", - "model_name", - "api_key", - "api_base", - "custom_provider", - "litellm_params", - ] - ): - # Build the validation config from existing + updates - validation_config = { - "provider": update_dict.get("provider", config.provider).value - if hasattr(update_dict.get("provider", config.provider), "value") - else update_dict.get("provider", config.provider.value), - "model_name": update_dict.get("model_name", config.model_name), - "api_key": update_dict.get("api_key", config.api_key), - "api_base": update_dict.get("api_base", config.api_base), - "custom_provider": update_dict.get( - "custom_provider", config.custom_provider - ), - "litellm_params": update_dict.get( - "litellm_params", config.litellm_params - ), - } - - is_valid, error_message = await validate_llm_config( - provider=validation_config["provider"], - model_name=validation_config["model_name"], - api_key=validation_config["api_key"], - api_base=validation_config["api_base"], - custom_provider=validation_config["custom_provider"], - litellm_params=validation_config["litellm_params"], - ) - - if not is_valid: - raise HTTPException( - status_code=400, - detail=f"Invalid LLM configuration: {error_message}", - ) - - # Apply updates - for key, value in update_dict.items(): - setattr(config, key, value) - - await session.commit() - await session.refresh(config) - - return _serialize_byok_config(config) - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to update NewLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to update configuration: {e!s}" - ) from e - - -@router.delete("/new-llm-configs/{config_id}", response_model=dict) -async def delete_new_llm_config( - config_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Delete a NewLLMConfig. - Requires LLM_CONFIGS_DELETE permission. - """ - try: - result = await session.execute( - select(NewLLMConfig).filter(NewLLMConfig.id == config_id) - ) - config = result.scalars().first() - - if not config: - raise HTTPException(status_code=404, detail="Configuration not found") - - # Verify user has permission - await check_permission( - session, - user, - config.search_space_id, - Permission.LLM_CONFIGS_DELETE.value, - "You don't have permission to delete LLM configurations in this search space", - ) - - await session.delete(config) - await session.commit() - - return {"message": "Configuration deleted successfully", "id": config_id} - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to delete NewLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to delete configuration: {e!s}" - ) from e diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 7c5fbf28b..592a9dd0e 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -1,27 +1,20 @@ import logging from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import func, update +from sqlalchemy import func from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.config import config from app.db import ( - ImageGenerationConfig, - NewChatThread, - NewLLMConfig, Permission, SearchSpace, SearchSpaceMembership, SearchSpaceRole, User, - VisionLLMConfig, get_async_session, get_default_roles_config, ) from app.schemas import ( - LLMPreferencesRead, - LLMPreferencesUpdate, SearchSpaceCreate, SearchSpaceRead, SearchSpaceUpdate, @@ -377,357 +370,6 @@ async def delete_search_space( ) from e -# ============================================================================= -# LLM Preferences Routes -# ============================================================================= - - -async def _get_llm_config_by_id( - session: AsyncSession, config_id: int | None -) -> dict | None: - """ - Get an LLM config by ID as a dictionary. Returns database config for positive IDs, - global config for negative IDs, Auto mode config for ID 0, or None if ID is None. - """ - if config_id is None: - return None - - # Auto mode (ID 0) - uses LiteLLM Router for load balancing - if config_id == 0: - return { - "id": 0, - "name": "Auto (Fastest)", - "description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling", - "provider": "AUTO", - "custom_provider": None, - "model_name": "auto", - "api_base": None, - "litellm_params": {}, - "system_instructions": "", - "use_default_system_instructions": True, - "citations_enabled": True, - "is_global": True, - "is_auto_mode": True, - } - - if config_id < 0: - # Global config - find from YAML - global_configs = config.GLOBAL_LLM_CONFIGS - for cfg in global_configs: - if cfg.get("id") == config_id: - return { - "id": cfg.get("id"), - "name": cfg.get("name"), - "description": cfg.get("description"), - "provider": cfg.get("provider") or cfg.get("litellm_provider"), - "custom_provider": cfg.get("custom_provider"), - "model_name": cfg.get("model_name"), - "api_base": cfg.get("api_base"), - "litellm_params": cfg.get("litellm_params", {}), - "system_instructions": cfg.get("system_instructions", ""), - "use_default_system_instructions": cfg.get( - "use_default_system_instructions", True - ), - "citations_enabled": cfg.get("citations_enabled", True), - "is_global": True, - } - return None - else: - # Database config - convert to dict - result = await session.execute( - select(NewLLMConfig).filter(NewLLMConfig.id == config_id) - ) - db_config = result.scalars().first() - if db_config: - return { - "id": db_config.id, - "name": db_config.name, - "description": db_config.description, - "provider": db_config.provider.value if db_config.provider else None, - "custom_provider": db_config.custom_provider, - "model_name": db_config.model_name, - "api_key": db_config.api_key, - "api_base": db_config.api_base, - "litellm_params": db_config.litellm_params or {}, - "system_instructions": db_config.system_instructions or "", - "use_default_system_instructions": db_config.use_default_system_instructions, - "citations_enabled": db_config.citations_enabled, - "created_at": db_config.created_at.isoformat() - if db_config.created_at - else None, - "search_space_id": db_config.search_space_id, - } - return None - - -async def _get_image_gen_config_by_id( - session: AsyncSession, config_id: int | None -) -> dict | None: - """ - Get an image generation config by ID as a dictionary. - Returns Auto mode for ID 0, global config for negative IDs, - DB ImageGenerationConfig for positive IDs, or None. - """ - if config_id is None: - return None - - if config_id == 0: - return { - "id": 0, - "name": "Auto (Fastest)", - "description": "Automatically routes requests across available image generation providers", - "provider": "AUTO", - "model_name": "auto", - "is_global": True, - "is_auto_mode": True, - "billing_tier": "free", - } - - if config_id < 0: - for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: - if cfg.get("id") == config_id: - return { - "id": cfg.get("id"), - "name": cfg.get("name"), - "description": cfg.get("description"), - "provider": cfg.get("provider") or cfg.get("litellm_provider"), - "custom_provider": cfg.get("custom_provider"), - "model_name": cfg.get("model_name"), - "api_base": cfg.get("api_base") or None, - "api_version": cfg.get("api_version") or None, - "litellm_params": cfg.get("litellm_params", {}), - "is_global": True, - "billing_tier": cfg.get("billing_tier", "free"), - } - return None - - # Positive ID: query ImageGenerationConfig table - result = await session.execute( - select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) - ) - db_config = result.scalars().first() - if db_config: - return { - "id": db_config.id, - "name": db_config.name, - "description": db_config.description, - "provider": db_config.provider.value if db_config.provider else None, - "custom_provider": db_config.custom_provider, - "model_name": db_config.model_name, - "api_base": db_config.api_base, - "api_version": db_config.api_version, - "litellm_params": db_config.litellm_params or {}, - "created_at": db_config.created_at.isoformat() - if db_config.created_at - else None, - "search_space_id": db_config.search_space_id, - } - return None - - -async def _get_vision_llm_config_by_id( - session: AsyncSession, config_id: int | None -) -> dict | None: - if config_id is None: - return None - - if config_id == 0: - return { - "id": 0, - "name": "Auto (Fastest)", - "description": "Automatically routes requests across available vision LLM providers", - "provider": "AUTO", - "model_name": "auto", - "is_global": True, - "is_auto_mode": True, - "billing_tier": "free", - } - - if config_id < 0: - for cfg in config.GLOBAL_VISION_LLM_CONFIGS: - if cfg.get("id") == config_id: - return { - "id": cfg.get("id"), - "name": cfg.get("name"), - "description": cfg.get("description"), - "provider": cfg.get("provider") or cfg.get("litellm_provider"), - "custom_provider": cfg.get("custom_provider"), - "model_name": cfg.get("model_name"), - "api_base": cfg.get("api_base") or None, - "api_version": cfg.get("api_version") or None, - "litellm_params": cfg.get("litellm_params", {}), - "is_global": True, - "billing_tier": cfg.get("billing_tier", "free"), - } - return None - - result = await session.execute( - select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) - ) - db_config = result.scalars().first() - if db_config: - return { - "id": db_config.id, - "name": db_config.name, - "description": db_config.description, - "provider": db_config.provider.value if db_config.provider else None, - "custom_provider": db_config.custom_provider, - "model_name": db_config.model_name, - "api_base": db_config.api_base, - "api_version": db_config.api_version, - "litellm_params": db_config.litellm_params or {}, - "created_at": db_config.created_at.isoformat() - if db_config.created_at - else None, - "search_space_id": db_config.search_space_id, - } - return None - - -@router.get( - "/search-spaces/{search_space_id}/llm-preferences", - response_model=LLMPreferencesRead, -) -async def get_llm_preferences( - search_space_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Get LLM preferences (role assignments) for a search space. - Requires LLM_CONFIGS_READ permission. - """ - try: - # Check permission - await check_permission( - session, - user, - search_space_id, - Permission.LLM_CONFIGS_READ.value, - "You don't have permission to view LLM preferences", - ) - - result = await session.execute( - select(SearchSpace).filter(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - - if not search_space: - raise HTTPException(status_code=404, detail="Search space not found") - - # Get full config objects for each role - agent_llm = await _get_llm_config_by_id(session, search_space.agent_llm_id) - image_generation_config = await _get_image_gen_config_by_id( - session, search_space.image_generation_config_id - ) - vision_llm_config = await _get_vision_llm_config_by_id( - session, search_space.vision_llm_config_id - ) - - return LLMPreferencesRead( - agent_llm_id=search_space.agent_llm_id, - image_generation_config_id=search_space.image_generation_config_id, - vision_llm_config_id=search_space.vision_llm_config_id, - agent_llm=agent_llm, - image_generation_config=image_generation_config, - vision_llm_config=vision_llm_config, - ) - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to get LLM preferences") - raise HTTPException( - status_code=500, detail=f"Failed to get LLM preferences: {e!s}" - ) from e - - -@router.put( - "/search-spaces/{search_space_id}/llm-preferences", - response_model=LLMPreferencesRead, -) -async def update_llm_preferences( - search_space_id: int, - preferences: LLMPreferencesUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Update LLM preferences (role assignments) for a search space. - Requires LLM_CONFIGS_UPDATE permission. - """ - try: - # Check permission - await check_permission( - session, - user, - search_space_id, - Permission.LLM_CONFIGS_UPDATE.value, - "You don't have permission to update LLM preferences", - ) - - result = await session.execute( - select(SearchSpace).filter(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - - if not search_space: - raise HTTPException(status_code=404, detail="Search space not found") - - # Update preferences - update_data = preferences.model_dump(exclude_unset=True) - previous_agent_llm_id = search_space.agent_llm_id - for key, value in update_data.items(): - setattr(search_space, key, value) - - agent_llm_changed = ( - "agent_llm_id" in update_data - and update_data["agent_llm_id"] != previous_agent_llm_id - ) - if agent_llm_changed: - await session.execute( - update(NewChatThread) - .where(NewChatThread.search_space_id == search_space_id) - .values(pinned_llm_config_id=None) - ) - logger.info( - "Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)", - search_space_id, - previous_agent_llm_id, - update_data["agent_llm_id"], - ) - - await session.commit() - await session.refresh(search_space) - - # Get full config objects for response - agent_llm = await _get_llm_config_by_id(session, search_space.agent_llm_id) - image_generation_config = await _get_image_gen_config_by_id( - session, search_space.image_generation_config_id - ) - vision_llm_config = await _get_vision_llm_config_by_id( - session, search_space.vision_llm_config_id - ) - - return LLMPreferencesRead( - agent_llm_id=search_space.agent_llm_id, - image_generation_config_id=search_space.image_generation_config_id, - vision_llm_config_id=search_space.vision_llm_config_id, - agent_llm=agent_llm, - image_generation_config=image_generation_config, - vision_llm_config=vision_llm_config, - ) - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to update LLM preferences") - raise HTTPException( - status_code=500, detail=f"Failed to update LLM preferences: {e!s}" - ) from e - - @router.get("/searchspaces/{search_space_id}/snapshots") async def list_search_space_snapshots( search_space_id: int, diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py deleted file mode 100644 index b93d25b9c..000000000 --- a/surfsense_backend/app/routes/vision_llm_routes.py +++ /dev/null @@ -1,304 +0,0 @@ -import logging - -from fastapi import APIRouter, Depends, HTTPException -from pydantic import BaseModel -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from app.config import config -from app.db import ( - Permission, - User, - VisionLLMConfig, - get_async_session, -) -from app.schemas import ( - GlobalVisionLLMConfigRead, - VisionLLMConfigCreate, - VisionLLMConfigRead, - VisionLLMConfigUpdate, -) -from app.services.vision_model_list_service import get_vision_model_list -from app.users import current_active_user -from app.utils.rbac import check_permission - -router = APIRouter() -logger = logging.getLogger(__name__) - - -# ============================================================================= -# Vision Model Catalogue (from OpenRouter, filtered for image-input models) -# ============================================================================= - - -class VisionModelListItem(BaseModel): - value: str - label: str - provider: str - context_window: str | None = None - - -@router.get("/vision-models", response_model=list[VisionModelListItem]) -async def list_vision_models( - user: User = Depends(current_active_user), -): - """Return vision-capable models sourced from OpenRouter (filtered by image input).""" - try: - return await get_vision_model_list() - except Exception as e: - logger.exception("Failed to fetch vision model list") - raise HTTPException( - status_code=500, detail=f"Failed to fetch vision model list: {e!s}" - ) from e - - -# ============================================================================= -# Global Vision LLM Configs (from YAML) -# ============================================================================= - - -@router.get( - "/global-vision-llm-configs", - response_model=list[GlobalVisionLLMConfigRead], -) -async def get_global_vision_llm_configs( - user: User = Depends(current_active_user), -): - try: - global_configs = config.GLOBAL_VISION_LLM_CONFIGS - safe_configs = [] - - if global_configs and len(global_configs) > 0: - safe_configs.append( - { - "id": 0, - "name": "Auto (Fastest)", - "description": "Automatically routes across available vision LLM providers.", - "provider": "AUTO", - "custom_provider": None, - "model_name": "auto", - "api_base": None, - "api_version": None, - "litellm_params": {}, - "is_global": True, - "is_auto_mode": True, - # Auto mode treated as free until per-deployment billing-tier - # surfacing lands; see ``get_vision_llm`` for parity. - "billing_tier": "free", - "is_premium": False, - } - ) - - for cfg in global_configs: - billing_tier = str(cfg.get("billing_tier", "free")).lower() - safe_configs.append( - { - "id": cfg.get("id"), - "name": cfg.get("name"), - "description": cfg.get("description"), - "provider": cfg.get("provider") or cfg.get("litellm_provider"), - "custom_provider": cfg.get("custom_provider"), - "model_name": cfg.get("model_name"), - "api_base": cfg.get("api_base") or None, - "api_version": cfg.get("api_version") or None, - "litellm_params": cfg.get("litellm_params", {}), - "is_global": True, - "billing_tier": billing_tier, - # Mirror chat (``new_llm_config_routes``) so the new-chat - # selector's premium badge logic keys off the same - # field across chat / image / vision tabs. - "is_premium": billing_tier == "premium", - "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), - "input_cost_per_token": cfg.get("input_cost_per_token"), - "output_cost_per_token": cfg.get("output_cost_per_token"), - } - ) - - return safe_configs - except Exception as e: - logger.exception("Failed to fetch global vision LLM configs") - raise HTTPException( - status_code=500, detail=f"Failed to fetch configs: {e!s}" - ) from e - - -# ============================================================================= -# VisionLLMConfig CRUD -# ============================================================================= - - -@router.post("/vision-llm-configs", response_model=VisionLLMConfigRead) -async def create_vision_llm_config( - config_data: VisionLLMConfigCreate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - try: - await check_permission( - session, - user, - config_data.search_space_id, - Permission.VISION_CONFIGS_CREATE.value, - "You don't have permission to create vision LLM configs in this search space", - ) - - db_config = VisionLLMConfig(**config_data.model_dump(), user_id=user.id) - session.add(db_config) - await session.commit() - await session.refresh(db_config) - return db_config - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to create VisionLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to create config: {e!s}" - ) from e - - -@router.get("/vision-llm-configs", response_model=list[VisionLLMConfigRead]) -async def list_vision_llm_configs( - search_space_id: int, - skip: int = 0, - limit: int = 100, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - try: - await check_permission( - session, - user, - search_space_id, - Permission.VISION_CONFIGS_READ.value, - "You don't have permission to view vision LLM configs in this search space", - ) - - result = await session.execute( - select(VisionLLMConfig) - .filter(VisionLLMConfig.search_space_id == search_space_id) - .order_by(VisionLLMConfig.created_at.desc()) - .offset(skip) - .limit(limit) - ) - return result.scalars().all() - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to list VisionLLMConfigs") - raise HTTPException( - status_code=500, detail=f"Failed to fetch configs: {e!s}" - ) from e - - -@router.get("/vision-llm-configs/{config_id}", response_model=VisionLLMConfigRead) -async def get_vision_llm_config( - config_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - try: - result = await session.execute( - select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) - ) - db_config = result.scalars().first() - if not db_config: - raise HTTPException(status_code=404, detail="Config not found") - - await check_permission( - session, - user, - db_config.search_space_id, - Permission.VISION_CONFIGS_READ.value, - "You don't have permission to view vision LLM configs in this search space", - ) - return db_config - - except HTTPException: - raise - except Exception as e: - logger.exception("Failed to get VisionLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to fetch config: {e!s}" - ) from e - - -@router.put("/vision-llm-configs/{config_id}", response_model=VisionLLMConfigRead) -async def update_vision_llm_config( - config_id: int, - update_data: VisionLLMConfigUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - try: - result = await session.execute( - select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) - ) - db_config = result.scalars().first() - if not db_config: - raise HTTPException(status_code=404, detail="Config not found") - - await check_permission( - session, - user, - db_config.search_space_id, - Permission.VISION_CONFIGS_CREATE.value, - "You don't have permission to update vision LLM configs in this search space", - ) - - for key, value in update_data.model_dump(exclude_unset=True).items(): - setattr(db_config, key, value) - - await session.commit() - await session.refresh(db_config) - return db_config - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to update VisionLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to update config: {e!s}" - ) from e - - -@router.delete("/vision-llm-configs/{config_id}", response_model=dict) -async def delete_vision_llm_config( - config_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - try: - result = await session.execute( - select(VisionLLMConfig).filter(VisionLLMConfig.id == config_id) - ) - db_config = result.scalars().first() - if not db_config: - raise HTTPException(status_code=404, detail="Config not found") - - await check_permission( - session, - user, - db_config.search_space_id, - Permission.VISION_CONFIGS_DELETE.value, - "You don't have permission to delete vision LLM configs in this search space", - ) - - await session.delete(db_config) - await session.commit() - return { - "message": "Vision LLM config deleted successfully", - "id": config_id, - } - - except HTTPException: - raise - except Exception as e: - await session.rollback() - logger.exception("Failed to delete VisionLLMConfig") - raise HTTPException( - status_code=500, detail=f"Failed to delete config: {e!s}" - ) from e diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 3c4fdfa83..f577397b6 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -34,11 +34,6 @@ from .folders import ( ) from .google_drive import DriveItem, GoogleDriveIndexingOptions, GoogleDriveIndexRequest from .image_generation import ( - GlobalImageGenConfigRead, - ImageGenerationConfigCreate, - ImageGenerationConfigPublic, - ImageGenerationConfigRead, - ImageGenerationConfigUpdate, ImageGenerationCreate, ImageGenerationListRead, ImageGenerationRead, @@ -74,16 +69,6 @@ from .new_chat import ( ThreadListItem, ThreadListResponse, ) -from .new_llm_config import ( - DefaultSystemInstructionsResponse, - GlobalNewLLMConfigRead, - LLMPreferencesRead, - LLMPreferencesUpdate, - NewLLMConfigCreate, - NewLLMConfigPublic, - NewLLMConfigRead, - NewLLMConfigUpdate, -) from .rbac_schemas import ( InviteAcceptRequest, InviteAcceptResponse, @@ -142,14 +127,6 @@ from .video_presentations import ( VideoPresentationRead, VideoPresentationUpdate, ) -from .vision_llm import ( - GlobalVisionLLMConfigRead, - VisionLLMConfigCreate, - VisionLLMConfigPublic, - VisionLLMConfigRead, - VisionLLMConfigUpdate, -) - __all__ = [ # Folder schemas "BulkDocumentMove", @@ -169,7 +146,6 @@ __all__ = [ "CreditPurchaseHistoryResponse", "CreditPurchaseRead", "CreditStripeStatusResponse", - "DefaultSystemInstructionsResponse", # Document schemas "DocumentBase", "DocumentMove", @@ -192,19 +168,10 @@ __all__ = [ "FolderRead", "FolderReorder", "FolderUpdate", - "GlobalImageGenConfigRead", - "GlobalNewLLMConfigRead", - # Vision LLM Config schemas - "GlobalVisionLLMConfigRead", "GoogleDriveIndexRequest", "GoogleDriveIndexingOptions", # Base schemas "IDModel", - # Image Generation Config schemas - "ImageGenerationConfigCreate", - "ImageGenerationConfigPublic", - "ImageGenerationConfigRead", - "ImageGenerationConfigUpdate", # Image Generation schemas "ImageGenerationCreate", "ImageGenerationListRead", @@ -216,9 +183,6 @@ __all__ = [ "InviteInfoResponse", "InviteRead", "InviteUpdate", - # LLM Preferences schemas - "LLMPreferencesRead", - "LLMPreferencesUpdate", # Log schemas "LogBase", "LogCreate", @@ -255,11 +219,6 @@ __all__ = [ "NewChatThreadRead", "NewChatThreadUpdate", "NewChatThreadWithMessages", - # NewLLMConfig schemas - "NewLLMConfigCreate", - "NewLLMConfigPublic", - "NewLLMConfigRead", - "NewLLMConfigUpdate", "PagePurchaseHistoryResponse", "PagePurchaseRead", "PaginatedResponse", @@ -303,8 +262,4 @@ __all__ = [ "VideoPresentationCreate", "VideoPresentationRead", "VideoPresentationUpdate", - "VisionLLMConfigCreate", - "VisionLLMConfigPublic", - "VisionLLMConfigRead", - "VisionLLMConfigUpdate", ] diff --git a/surfsense_backend/app/schemas/image_generation.py b/surfsense_backend/app/schemas/image_generation.py index 4262b2b3f..83671cc77 100644 --- a/surfsense_backend/app/schemas/image_generation.py +++ b/surfsense_backend/app/schemas/image_generation.py @@ -1,109 +1,10 @@ -""" -Pydantic schemas for Image Generation configs and generation requests. +"""Pydantic schemas for image generation requests/results.""" -ImageGenerationConfig: CRUD schemas for user-created image gen model configs. -ImageGeneration: Schemas for the actual image generation requests/results. -GlobalImageGenConfigRead: Schema for admin-configured YAML configs. -""" - -import uuid from datetime import datetime from typing import Any from pydantic import BaseModel, ConfigDict, Field -from app.db import ImageGenProvider - -# ============================================================================= -# ImageGenerationConfig CRUD Schemas -# ============================================================================= - - -class ImageGenerationConfigBase(BaseModel): - """Base schema with fields for ImageGenerationConfig.""" - - name: str = Field( - ..., max_length=100, description="User-friendly name for the config" - ) - description: str | None = Field( - None, max_length=500, description="Optional description" - ) - provider: ImageGenProvider = Field( - ..., - description="Image generation provider (OpenAI, Azure, Google AI Studio, Vertex AI, Bedrock, Recraft, OpenRouter, Xinference, Nscale)", - ) - custom_provider: str | None = Field( - None, max_length=100, description="Custom provider name" - ) - model_name: str = Field( - ..., max_length=100, description="Model name (e.g., dall-e-3, gpt-image-1)" - ) - api_key: str = Field(..., description="API key for the provider") - api_base: str | None = Field( - None, max_length=500, description="Optional API base URL" - ) - api_version: str | None = Field( - None, - max_length=50, - description="Azure-specific API version (e.g., '2024-02-15-preview')", - ) - litellm_params: dict[str, Any] | None = Field( - default=None, description="Additional LiteLLM parameters" - ) - - -class ImageGenerationConfigCreate(ImageGenerationConfigBase): - """Schema for creating a new ImageGenerationConfig.""" - - search_space_id: int = Field( - ..., description="Search space ID to associate the config with" - ) - - -class ImageGenerationConfigUpdate(BaseModel): - """Schema for updating an existing ImageGenerationConfig. All fields optional.""" - - name: str | None = Field(None, max_length=100) - description: str | None = Field(None, max_length=500) - provider: ImageGenProvider | None = None - custom_provider: str | None = Field(None, max_length=100) - model_name: str | None = Field(None, max_length=100) - api_key: str | None = None - api_base: str | None = Field(None, max_length=500) - api_version: str | None = Field(None, max_length=50) - litellm_params: dict[str, Any] | None = None - - -class ImageGenerationConfigRead(ImageGenerationConfigBase): - """Schema for reading an ImageGenerationConfig (includes id and timestamps).""" - - id: int - created_at: datetime - search_space_id: int - user_id: uuid.UUID - - model_config = ConfigDict(from_attributes=True) - - -class ImageGenerationConfigPublic(BaseModel): - """Public schema that hides the API key (for list views).""" - - id: int - name: str - description: str | None = None - provider: ImageGenProvider - custom_provider: str | None = None - model_name: str - api_base: str | None = None - api_version: str | None = None - litellm_params: dict[str, Any] | None = None - created_at: datetime - search_space_id: int - user_id: uuid.UUID - - model_config = ConfigDict(from_attributes=True) - - # ============================================================================= # ImageGeneration (request/result) Schemas # ============================================================================= @@ -136,12 +37,12 @@ class ImageGenerationCreate(BaseModel): search_space_id: int = Field( ..., description="Search space ID to associate the generation with" ) - image_generation_config_id: int | None = Field( + image_gen_model_id: int | None = Field( None, description=( - "Image generation config ID. " - "0 = Auto mode (router), negative = global YAML config, positive = DB config. " - "If not provided, uses the search space's image_generation_config_id preference." + "Image generation model ID. " + "0 = Auto mode, negative = GLOBAL model, positive = BYOK Model row. " + "If not provided, uses the search space's image_gen_model_id preference." ), ) @@ -157,7 +58,7 @@ class ImageGenerationRead(BaseModel): size: str | None = None style: str | None = None response_format: str | None = None - image_generation_config_id: int | None = None + image_gen_model_id: int | None = None response_data: dict[str, Any] | None = None error_message: str | None = None search_space_id: int @@ -204,57 +105,3 @@ class ImageGenerationListRead(BaseModel): image_count=image_count, ) - -# ============================================================================= -# Global Image Gen Config (from YAML) -# ============================================================================= - - -class GlobalImageGenConfigRead(BaseModel): - """ - Schema for reading global image generation configs from YAML. - Global configs have negative IDs. API key is hidden. - ID 0 is reserved for Auto mode (LiteLLM Router load balancing). - - The ``billing_tier`` field allows the frontend to show a Premium/Free - badge and (more importantly) tells the backend whether to debit the - user's premium credit pool when this config is used. ``"free"`` is - the default for backward compatibility — admins must explicitly opt - a global config into ``"premium"``. - """ - - id: int = Field( - ..., - description="Config ID: 0 for Auto mode, negative for global configs", - ) - name: str - description: str | None = None - provider: str - custom_provider: str | None = None - model_name: str - api_base: str | None = None - api_version: str | None = None - litellm_params: dict[str, Any] | None = None - is_global: bool = True - is_auto_mode: bool = False - billing_tier: str = Field( - default="free", - description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", - ) - is_premium: bool = Field( - default=False, - description=( - "Convenience boolean derived server-side from " - "``billing_tier == 'premium'``. The new-chat model selector " - "keys its Free/Premium badge off this field for parity with " - "chat (`GlobalLLMConfigRead.is_premium`)." - ), - ) - quota_reserve_micros: int | None = Field( - default=None, - description=( - "Optional override for the reservation amount (in micro-USD) used when " - "this image generation is premium. Falls back to " - "QUOTA_DEFAULT_IMAGE_RESERVE_MICROS when omitted." - ), - ) diff --git a/surfsense_backend/app/schemas/new_llm_config.py b/surfsense_backend/app/schemas/new_llm_config.py deleted file mode 100644 index 2f04a9e66..000000000 --- a/surfsense_backend/app/schemas/new_llm_config.py +++ /dev/null @@ -1,256 +0,0 @@ -""" -Pydantic schemas for the NewLLMConfig API. - -NewLLMConfig combines model settings with prompt configuration: -- LLM provider, model, API key, etc. -- Configurable system instructions -- Citation toggle -""" - -import uuid -from datetime import datetime -from typing import Any - -from pydantic import BaseModel, ConfigDict, Field - -from app.db import LiteLLMProvider - - -class NewLLMConfigBase(BaseModel): - """Base schema with common fields for NewLLMConfig.""" - - name: str = Field( - ..., max_length=100, description="User-friendly name for the configuration" - ) - description: str | None = Field( - None, max_length=500, description="Optional description" - ) - - # Model Configuration - provider: LiteLLMProvider = Field(..., description="LiteLLM provider type") - custom_provider: str | None = Field( - None, max_length=100, description="Custom provider name when provider is CUSTOM" - ) - model_name: str = Field( - ..., max_length=100, description="Model name without provider prefix" - ) - api_key: str = Field(..., description="API key for the provider") - api_base: str | None = Field( - None, max_length=500, description="Optional API base URL" - ) - litellm_params: dict[str, Any] | None = Field( - default=None, description="Additional LiteLLM parameters" - ) - - # Prompt Configuration - system_instructions: str = Field( - default="", - description="Custom system instructions. Empty string uses default SURFSENSE_SYSTEM_INSTRUCTIONS.", - ) - use_default_system_instructions: bool = Field( - default=True, - description="Whether to use default instructions when system_instructions is empty", - ) - citations_enabled: bool = Field( - default=True, - description="Whether to include citation instructions in the system prompt", - ) - - -class NewLLMConfigCreate(NewLLMConfigBase): - """Schema for creating a new NewLLMConfig.""" - - search_space_id: int = Field( - ..., description="Search space ID to associate the config with" - ) - - -class NewLLMConfigUpdate(BaseModel): - """Schema for updating an existing NewLLMConfig. All fields are optional.""" - - name: str | None = Field(None, max_length=100) - description: str | None = Field(None, max_length=500) - - # Model Configuration - provider: LiteLLMProvider | None = None - custom_provider: str | None = Field(None, max_length=100) - model_name: str | None = Field(None, max_length=100) - api_key: str | None = None - api_base: str | None = Field(None, max_length=500) - litellm_params: dict[str, Any] | None = None - - # Prompt Configuration - system_instructions: str | None = None - use_default_system_instructions: bool | None = None - citations_enabled: bool | None = None - - -class NewLLMConfigRead(NewLLMConfigBase): - """Schema for reading a NewLLMConfig (includes id and timestamps).""" - - id: int - created_at: datetime - search_space_id: int - user_id: uuid.UUID - # Capability flag derived at the API boundary (no DB column). Default - # True matches the conservative-allow stance — a BYOK row that the - # route forgot to augment is not pre-judged. The streaming-task - # safety net is the only place a False actually blocks a request. - supports_image_input: bool = Field( - default=True, - description=( - "Whether the BYOK chat config can accept image inputs. Derived " - "at the route boundary from LiteLLM's authoritative model map " - "(``litellm.supports_vision``) — there is no DB column. " - "Default True is the conservative-allow stance for unknown / " - "unmapped models." - ), - ) - - model_config = ConfigDict(from_attributes=True) - - -class NewLLMConfigPublic(BaseModel): - """ - Public schema for NewLLMConfig that hides the API key. - Used when returning configs in list views or to users who shouldn't see keys. - """ - - id: int - name: str - description: str | None = None - - # Model Configuration (no api_key) - provider: LiteLLMProvider - custom_provider: str | None = None - model_name: str - api_base: str | None = None - litellm_params: dict[str, Any] | None = None - - # Prompt Configuration - system_instructions: str - use_default_system_instructions: bool - citations_enabled: bool - - created_at: datetime - search_space_id: int - user_id: uuid.UUID - # Capability flag derived at the API boundary (see NewLLMConfigRead). - supports_image_input: bool = Field( - default=True, - description=( - "Whether the BYOK chat config can accept image inputs. Derived " - "at the route boundary from LiteLLM's authoritative model map. " - "Default True is the conservative-allow stance." - ), - ) - - model_config = ConfigDict(from_attributes=True) - - -class DefaultSystemInstructionsResponse(BaseModel): - """Response schema for getting default system instructions.""" - - default_system_instructions: str = Field( - ..., description="The default SURFSENSE_SYSTEM_INSTRUCTIONS template" - ) - - -class GlobalNewLLMConfigRead(BaseModel): - """ - Schema for reading global LLM configs from YAML. - Global configs have negative IDs and no search_space_id. - API key is hidden for security. - - ID 0 is reserved for Auto mode which uses LiteLLM Router for load balancing. - """ - - id: int = Field( - ..., - description="Config ID: 0 for Auto mode, negative for global configs", - ) - name: str - description: str | None = None - - # Model Configuration (no api_key) - provider: str # String because YAML doesn't enforce enum, "AUTO" for Auto mode - custom_provider: str | None = None - model_name: str - api_base: str | None = None - litellm_params: dict[str, Any] | None = None - - # Prompt Configuration - system_instructions: str = "" - use_default_system_instructions: bool = True - citations_enabled: bool = True - - is_global: bool = True # Always true for global configs - is_auto_mode: bool = False # True only for Auto mode (ID 0) - - billing_tier: str = "free" - is_premium: bool = False - anonymous_enabled: bool = False - seo_enabled: bool = False - seo_slug: str | None = None - seo_title: str | None = None - seo_description: str | None = None - quota_reserve_tokens: int | None = None - supports_image_input: bool = Field( - default=True, - description=( - "Whether the model accepts image inputs (multimodal vision). " - "Derived server-side: OpenRouter dynamic configs use " - "``architecture.input_modalities``; YAML / BYOK use LiteLLM's " - "authoritative model map (``litellm.supports_vision``). The " - "new-chat selector hints with a 'No image' badge when this is " - "False and there are pending image attachments. The streaming " - "task fails fast only when LiteLLM *explicitly* marks a model " - "as text-only — unknown / unmapped models default-allow." - ), - ) - - -# ============================================================================= -# LLM Preferences Schemas (for role assignments) -# ============================================================================= - - -class LLMPreferencesRead(BaseModel): - """Schema for reading LLM preferences (role assignments) for a search space.""" - - agent_llm_id: int | None = Field( - None, description="ID of the LLM config to use for agent/chat tasks" - ) - image_generation_config_id: int | None = Field( - None, description="ID of the image generation config to use" - ) - vision_llm_config_id: int | None = Field( - None, - description="ID of the vision LLM config to use for vision/screenshot analysis", - ) - agent_llm: dict[str, Any] | None = Field( - None, description="Full config for chat model" - ) - image_generation_config: dict[str, Any] | None = Field( - None, description="Full config for image generation" - ) - vision_llm_config: dict[str, Any] | None = Field( - None, description="Full config for vision LLM" - ) - - model_config = ConfigDict(from_attributes=True) - - -class LLMPreferencesUpdate(BaseModel): - """Schema for updating LLM preferences.""" - - agent_llm_id: int | None = Field( - None, description="ID of the LLM config to use for agent/chat tasks" - ) - image_generation_config_id: int | None = Field( - None, description="ID of the image generation config to use" - ) - vision_llm_config_id: int | None = Field( - None, - description="ID of the vision LLM config to use for vision/screenshot analysis", - ) diff --git a/surfsense_backend/app/schemas/vision_llm.py b/surfsense_backend/app/schemas/vision_llm.py deleted file mode 100644 index d0eeaf5c6..000000000 --- a/surfsense_backend/app/schemas/vision_llm.py +++ /dev/null @@ -1,116 +0,0 @@ -import uuid -from datetime import datetime -from typing import Any - -from pydantic import BaseModel, ConfigDict, Field - -from app.db import VisionProvider - - -class VisionLLMConfigBase(BaseModel): - name: str = Field(..., max_length=100) - description: str | None = Field(None, max_length=500) - provider: VisionProvider = Field(...) - custom_provider: str | None = Field(None, max_length=100) - model_name: str = Field(..., max_length=100) - api_key: str = Field(...) - api_base: str | None = Field(None, max_length=500) - api_version: str | None = Field(None, max_length=50) - litellm_params: dict[str, Any] | None = Field(default=None) - - -class VisionLLMConfigCreate(VisionLLMConfigBase): - search_space_id: int = Field(...) - - -class VisionLLMConfigUpdate(BaseModel): - name: str | None = Field(None, max_length=100) - description: str | None = Field(None, max_length=500) - provider: VisionProvider | None = None - custom_provider: str | None = Field(None, max_length=100) - model_name: str | None = Field(None, max_length=100) - api_key: str | None = None - api_base: str | None = Field(None, max_length=500) - api_version: str | None = Field(None, max_length=50) - litellm_params: dict[str, Any] | None = None - - -class VisionLLMConfigRead(VisionLLMConfigBase): - id: int - created_at: datetime - search_space_id: int - user_id: uuid.UUID - - model_config = ConfigDict(from_attributes=True) - - -class VisionLLMConfigPublic(BaseModel): - id: int - name: str - description: str | None = None - provider: VisionProvider - custom_provider: str | None = None - model_name: str - api_base: str | None = None - api_version: str | None = None - litellm_params: dict[str, Any] | None = None - created_at: datetime - search_space_id: int - user_id: uuid.UUID - - model_config = ConfigDict(from_attributes=True) - - -class GlobalVisionLLMConfigRead(BaseModel): - """Schema for reading global vision LLM configs from YAML. - - The ``billing_tier`` field allows the frontend to show a Premium/Free - badge and (more importantly) tells the backend whether to debit the - user's premium credit pool when this config is used. ``"free"`` is - the default for backward compatibility — admins must explicitly opt - a global config into ``"premium"``. - """ - - id: int = Field(...) - name: str - description: str | None = None - provider: str - custom_provider: str | None = None - model_name: str - api_base: str | None = None - api_version: str | None = None - litellm_params: dict[str, Any] | None = None - is_global: bool = True - is_auto_mode: bool = False - billing_tier: str = Field( - default="free", - description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", - ) - is_premium: bool = Field( - default=False, - description=( - "Convenience boolean derived server-side from " - "``billing_tier == 'premium'``. The new-chat model selector " - "keys its Free/Premium badge off this field for parity with " - "chat (`GlobalLLMConfigRead.is_premium`)." - ), - ) - quota_reserve_tokens: int | None = Field( - default=None, - description=( - "Optional override for the per-call reservation in *tokens* — " - "converted to micro-USD via the model's input/output prices at " - "reservation time. Falls back to QUOTA_DEFAULT_RESERVE_TOKENS." - ), - ) - input_cost_per_token: float | None = Field( - default=None, - description=( - "Optional input price in USD/token. Used by pricing_registration to " - "register custom Azure / OpenRouter aliases with LiteLLM at startup." - ), - ) - output_cost_per_token: float | None = Field( - default=None, - description="Optional output price in USD/token. Pair with input_cost_per_token.", - ) diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index b4f1bafc9..dfd7c7be3 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -1,13 +1,13 @@ -"""Resolve and persist Auto (Fastest) model pins per chat thread. +"""Resolve and persist Auto model pins per chat thread. -Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we -resolve that virtual mode to one concrete global LLM config exactly once and +Auto is represented by ``chat_model_id == 0``. For chat threads we +resolve that virtual mode to one concrete global model exactly once and persist the chosen config id on ``new_chat_threads.pinned_llm_config_id`` so subsequent turns are stable. Single-writer invariant: this module is the only writer of ``NewChatThread.pinned_llm_config_id`` (aside from the bulk clear in -``search_spaces_routes`` when a search space's ``agent_llm_id`` changes). +``model_connections_routes`` when a search space's ``chat_model_id`` changes). Therefore a non-NULL value unambiguously means "this thread has an Auto-resolved pin"; no separate source/policy column is needed. """ @@ -33,8 +33,10 @@ from app.services.token_quota_service import TokenQuotaService logger = logging.getLogger(__name__) -AUTO_FASTEST_ID = 0 -AUTO_FASTEST_MODE = "auto_fastest" +AUTO_MODE_ID = 0 +# Stable internal hash namespace for deterministic per-thread selection. +# Do not rename: changing this rebalances Auto's model choice for new pins. +AUTO_PIN_HASH_NAMESPACE = "auto_fastest" _RUNTIME_COOLDOWN_SECONDS = 600 _HEALTHY_TTL_SECONDS = 45 @@ -383,7 +385,7 @@ def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]: pool = tier_a if tier_a else eligible pool = sorted(pool, key=lambda c: -int(c.get("quality_score") or 0)) top_k = pool[:_QUALITY_TOP_K] - digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest() + digest = hashlib.sha256(f"{AUTO_PIN_HASH_NAMESPACE}:{thread_id}".encode()).digest() idx = int.from_bytes(digest[:8], "big") % len(top_k) return top_k[idx], len(top_k) @@ -425,7 +427,7 @@ async def resolve_or_get_pinned_llm_config_id( exclude_config_ids: set[int] | None = None, requires_image_input: bool = False, ) -> AutoPinResolution: - """Resolve Auto (Fastest) to one concrete config id and persist the pin. + """Resolve Auto to one concrete config id and persist the pin. For non-auto selections, this function clears any existing pin and returns the selected id as-is. @@ -457,7 +459,7 @@ async def resolve_or_get_pinned_llm_config_id( ) # Explicit model selected: clear any stale pin. - if selected_llm_config_id != AUTO_FASTEST_ID: + if selected_llm_config_id != AUTO_MODE_ID: if thread.pinned_llm_config_id is not None: thread.pinned_llm_config_id = None await session.commit() diff --git a/surfsense_backend/app/services/billable_calls.py b/surfsense_backend/app/services/billable_calls.py index f21f52e14..15a3c3e55 100644 --- a/surfsense_backend/app/services/billable_calls.py +++ b/surfsense_backend/app/services/billable_calls.py @@ -450,10 +450,10 @@ async def _resolve_agent_billing_for_search_space( Used by Celery tasks (podcast generation, video presentation) to bill the search-space owner's premium credit pool when the chat model is premium. - Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``: + Resolution rules mirror the chat model role resolver: - - Search space not found / no ``agent_llm_id``: raise ``ValueError``. - - **Auto mode** (``id == AUTO_FASTEST_ID == 0``): + - Search space not found / no ``chat_model_id``: raise ``ValueError``. + - **Auto mode** (``id == AUTO_MODE_ID == 0``): * ``thread_id`` is set: delegate to ``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and recurse into the resolved id. Reuses chat's existing pin if present @@ -469,9 +469,8 @@ async def _resolve_agent_billing_for_search_space( (defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault), ``base_model = litellm_params.get("base_model") or model_name`` — NOT provider-prefixed, matching chat's cost-map lookup convention. - - **Positive id** (user BYOK ``NewLLMConfig``): always free (matches - ``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``); - ``base_model`` from ``litellm_params`` or ``model_name``. + - **Positive id** (user BYOK ``Model``): always free; ``base_model`` from + the model catalog override or the upstream ``model_id``. Note on imports: ``llm_service``, ``auto_model_pin_service``, and ``llm_router_service`` are imported lazily inside the function body to @@ -480,8 +479,9 @@ async def _resolve_agent_billing_for_search_space( ``billable_calls.py``'s module load path. """ from sqlalchemy import select + from sqlalchemy.orm import selectinload - from app.db import NewLLMConfig, SearchSpace + from app.db import Model, SearchSpace result = await session.execute( select(SearchSpace).where(SearchSpace.id == search_space_id) @@ -490,20 +490,20 @@ async def _resolve_agent_billing_for_search_space( if search_space is None: raise ValueError(f"Search space {search_space_id} not found") - agent_llm_id = search_space.agent_llm_id - if agent_llm_id is None: + chat_model_id = search_space.chat_model_id + if chat_model_id is None: raise ValueError( - f"Search space {search_space_id} has no agent_llm_id configured" + f"Search space {search_space_id} has no chat_model_id configured" ) owner_user_id: UUID = search_space.user_id from app.services.auto_model_pin_service import ( - AUTO_FASTEST_ID, + AUTO_MODE_ID, resolve_or_get_pinned_llm_config_id, ) - if agent_llm_id == AUTO_FASTEST_ID: + if chat_model_id == AUTO_MODE_ID: if thread_id is None: return owner_user_id, "free", "auto" try: @@ -512,7 +512,7 @@ async def _resolve_agent_billing_for_search_space( thread_id=thread_id, search_space_id=search_space_id, user_id=str(owner_user_id), - selected_llm_config_id=AUTO_FASTEST_ID, + selected_llm_config_id=AUTO_MODE_ID, ) except ValueError: logger.warning( @@ -523,28 +523,35 @@ async def _resolve_agent_billing_for_search_space( exc_info=True, ) return owner_user_id, "free", "auto" - agent_llm_id = resolution.resolved_llm_config_id + chat_model_id = resolution.resolved_llm_config_id - if agent_llm_id < 0: + if chat_model_id < 0: from app.services.llm_service import get_global_llm_config - cfg = get_global_llm_config(agent_llm_id) or {} + cfg = get_global_llm_config(chat_model_id) or {} billing_tier = str(cfg.get("billing_tier", "free")).lower() litellm_params = cfg.get("litellm_params") or {} base_model = litellm_params.get("base_model") or cfg.get("model_name") or "" return owner_user_id, billing_tier, base_model - nlc_result = await session.execute( - select(NewLLMConfig).where( - NewLLMConfig.id == agent_llm_id, - NewLLMConfig.search_space_id == search_space_id, - ) + model_result = await session.execute( + select(Model) + .options(selectinload(Model.connection)) + .where(Model.id == chat_model_id, Model.enabled.is_(True)) ) - nlc = nlc_result.scalars().first() + model = model_result.scalars().first() base_model = "" - if nlc is not None: - litellm_params = nlc.litellm_params or {} - base_model = litellm_params.get("base_model") or nlc.model_name or "" + if ( + model is not None + and model.connection is not None + and model.connection.enabled + and ( + model.connection.search_space_id in (None, search_space_id) + and model.connection.user_id in (None, owner_user_id) + ) + ): + catalog = model.catalog or {} + base_model = catalog.get("base_model") or model.model_id or "" return owner_user_id, "free", base_model diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index eadb4dbf8..277929e96 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -14,7 +14,11 @@ from app.services.auto_model_pin_service import ( auto_model_candidates, choose_auto_model_candidate, ) -from app.services.llm_router_service import AUTO_MODE_ID, ChatLiteLLMRouter, is_auto_mode +from app.services.llm_router_service import ( + AUTO_MODE_ID, + ChatLiteLLMRouter, + is_auto_mode, +) from app.services.model_capabilities import has_capability from app.services.model_resolver import native_connection_from_config, to_litellm from app.services.token_tracking_service import token_tracker @@ -96,26 +100,16 @@ class LLMRole: def get_global_llm_config(llm_config_id: int) -> dict | None: """ Get a global LLM configuration by ID. - Global configs have negative IDs. ID 0 is reserved for Auto mode. + Global configs have negative IDs. Auto mode (ID 0) is resolved through the + model-candidate pipeline, not this legacy config lookup. Args: - llm_config_id: The ID of the global config (should be negative or 0 for Auto) + llm_config_id: The ID of the global config (must be negative) Returns: dict: Global config dictionary or None if not found """ - # Auto mode (ID 0) is handled separately via the router - if llm_config_id == AUTO_MODE_ID: - return { - "id": AUTO_MODE_ID, - "name": "Auto (Fastest)", - "description": "Automatically routes requests across available LLM providers for optimal performance and rate limit handling", - "provider": "AUTO", - "model_name": "auto", - "is_auto_mode": True, - } - - if llm_config_id > 0: + if llm_config_id >= 0: return None for cfg in config.GLOBAL_LLM_CONFIGS: diff --git a/surfsense_backend/app/services/model_list_service.py b/surfsense_backend/app/services/model_list_service.py index 0761d7e4f..ffb430756 100644 --- a/surfsense_backend/app/services/model_list_service.py +++ b/surfsense_backend/app/services/model_list_service.py @@ -24,7 +24,7 @@ CACHE_TTL_SECONDS = 86400 # 24 hours _cache: list[dict] | None = None _cache_timestamp: float = 0 -# Maps OpenRouter provider slug → our LiteLLMProvider enum value. +# Maps OpenRouter provider slug to native LiteLLM provider prefixes. # Only providers where the model-name part (after the slash) can be # used directly with the native provider's litellm prefix are listed. # diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index fbb70eb5a..8f4c4cb5f 100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -281,7 +281,7 @@ def _generate_configs( OpenRouter's own ``openrouter/free`` meta-router is filtered out upstream via ``_EXCLUDED_MODEL_IDS``; we don't expose a redundant auto-select layer - because our own Auto (Fastest) pin + 24 h refresh + repair logic already + because our own Auto pin + 24 h refresh + repair logic already cover the catalogue-churn case. """ id_offset: int = settings.get("id_offset", -10000) @@ -346,7 +346,7 @@ def _generate_configs( # ``"No endpoints found that support image input"``. "supports_image_input": bool(normalized.get("supports_image_input")), _OPENROUTER_DYNAMIC_MARKER: True, - # Auto (Fastest) ranking metadata. ``quality_score`` is initialised + # Auto ranking metadata. ``quality_score`` is initialised # to the static score and gets re-blended with health on the next # ``_enrich_health`` pass (synchronous on refresh, deferred on cold # start so startup latency is unchanged). @@ -361,11 +361,7 @@ def _generate_configs( return configs -# ID-offset bands used to keep dynamic OpenRouter configs in their own -# namespace per surface. Image / vision get separate bands so a single -# Postgres-INTEGER cfg ID is unambiguous about which selector it belongs to. _OPENROUTER_IMAGE_ID_OFFSET_DEFAULT = -20000 -_OPENROUTER_VISION_ID_OFFSET_DEFAULT = -30000 def _generate_image_gen_configs( @@ -431,89 +427,6 @@ def _generate_image_gen_configs( return configs -def _generate_vision_llm_configs( - raw_models: list[dict], settings: dict[str, Any] -) -> list[dict]: - """Convert OpenRouter vision-capable LLMs into global vision-LLM config - dicts (matches the YAML shape consumed by ``vision_llm_routes``). - - Filter: - - architecture.input_modalities contains "image" - - architecture.output_modalities contains "text" - - compatible provider (excluded slugs blocked) - - allowed model id (excluded list blocked) - - Vision-LLM is invoked from the indexer (image extraction during - document upload) via ``langchain_litellm.ChatLiteLLM.ainvoke``, so - the chat-only ``_supports_tool_calling`` and ``_has_sufficient_context`` - filters do not apply: a small-context vision model that doesn't - advertise tool-calling is still perfectly viable for "describe this - image" prompts. - """ - id_offset: int = int( - settings.get("vision_id_offset") or _OPENROUTER_VISION_ID_OFFSET_DEFAULT - ) - api_key: str = settings.get("api_key", "") - rpm: int = settings.get("rpm", 200) - tpm: int = settings.get("tpm", 1_000_000) - free_rpm: int = settings.get("free_rpm", 20) - free_tpm: int = settings.get("free_tpm", 100_000) - quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000) - litellm_params: dict = settings.get("litellm_params") or {} - - vision_models = [ - m - for m in raw_models - if supports_image_input(m) - and _shared_is_compatible_provider(m) - and _shared_is_allowed_model(m) - and "/" in m.get("id", "") - ] - - configs: list[dict] = [] - taken: set[int] = set() - for model in vision_models: - model_id: str = model["id"] - name: str = model.get("name", model_id) - tier = _openrouter_tier(model) - pricing = model.get("pricing") or {} - - # Capture per-token prices so ``pricing_registration`` can - # register them with LiteLLM at startup (and so the cost - # estimator in ``estimate_call_reserve_micros`` can resolve - # them at reserve time). - try: - input_cost = float(pricing.get("prompt", 0) or 0) - except (TypeError, ValueError): - input_cost = 0.0 - try: - output_cost = float(pricing.get("completion", 0) or 0) - except (TypeError, ValueError): - output_cost = 0.0 - - cfg: dict[str, Any] = { - "id": _stable_config_id(model_id, id_offset, taken), - "name": name, - "description": f"{name} via OpenRouter (vision)", - "provider": "openrouter", - "model_name": model_id, - "api_key": api_key, - "api_base": "https://openrouter.ai/api/v1", - "api_version": None, - "rpm": free_rpm if tier == "free" else rpm, - "tpm": free_tpm if tier == "free" else tpm, - "litellm_params": dict(litellm_params), - "billing_tier": tier, - "quota_reserve_tokens": quota_reserve_tokens, - "input_cost_per_token": input_cost or None, - "output_cost_per_token": output_cost or None, - _OPENROUTER_DYNAMIC_MARKER: True, - } - configs.append(cfg) - - return configs - - class OpenRouterIntegrationService: """Singleton that manages the dynamic OpenRouter model catalogue.""" @@ -724,7 +637,7 @@ class OpenRouterIntegrationService: return counts # ------------------------------------------------------------------ - # Auto (Fastest) health enrichment + # Auto health enrichment # ------------------------------------------------------------------ async def _enrich_health_safely( diff --git a/surfsense_backend/app/services/pricing_registration.py b/surfsense_backend/app/services/pricing_registration.py index 9e4e3b552..7343df737 100644 --- a/surfsense_backend/app/services/pricing_registration.py +++ b/surfsense_backend/app/services/pricing_registration.py @@ -154,10 +154,8 @@ def _register_chat_shape_configs( input_cost = _safe_float(entry.get("prompt")) output_cost = _safe_float(entry.get("completion")) else: - # Vision configs from ``_generate_vision_llm_configs`` - # carry their pricing inline because the OpenRouter - # raw-pricing cache is keyed by chat-catalogue model_id; - # vision flows pick up the inline values here. + # Some dynamically materialized configs can carry pricing + # inline when the raw OpenRouter cache has no matching entry. input_cost = _safe_float(cfg.get("input_cost_per_token")) output_cost = _safe_float(cfg.get("output_cost_per_token")) if input_cost == 0.0 and output_cost == 0.0: diff --git a/surfsense_backend/app/services/quality_score.py b/surfsense_backend/app/services/quality_score.py index 9cc9c21ac..737dd7c2f 100644 --- a/surfsense_backend/app/services/quality_score.py +++ b/surfsense_backend/app/services/quality_score.py @@ -1,4 +1,4 @@ -"""Pure-function quality scoring for Auto (Fastest) model selection. +"""Pure-function quality scoring for Auto model selection. This module is import-free of any service / request-path dependencies. All numbers are computed once during the OpenRouter refresh tick (or YAML load) diff --git a/surfsense_backend/app/services/vision_llm_router_service.py b/surfsense_backend/app/services/vision_llm_router_service.py deleted file mode 100644 index 0ff716324..000000000 --- a/surfsense_backend/app/services/vision_llm_router_service.py +++ /dev/null @@ -1,160 +0,0 @@ -import logging -from typing import Any - -from litellm import Router - -from app.services.model_resolver import native_connection_from_config, to_litellm - -logger = logging.getLogger(__name__) - -VISION_AUTO_MODE_ID = 0 - -class VisionLLMRouterService: - _instance = None - _router: Router | None = None - _model_list: list[dict] = [] - _router_settings: dict = {} - _initialized: bool = False - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - @classmethod - def get_instance(cls) -> "VisionLLMRouterService": - if cls._instance is None: - cls._instance = cls() - return cls._instance - - @classmethod - def initialize( - cls, - global_configs: list[dict], - router_settings: dict | None = None, - ) -> None: - instance = cls.get_instance() - - if instance._initialized: - logger.debug("Vision LLM Router already initialized, skipping") - return - - model_list = [] - for config in global_configs: - deployment = cls._config_to_deployment(config) - if deployment: - model_list.append(deployment) - - if not model_list: - logger.warning( - "No valid vision LLM configs found for router initialization" - ) - return - - instance._model_list = model_list - instance._router_settings = router_settings or {} - - default_settings = { - "routing_strategy": "usage-based-routing", - "num_retries": 3, - "allowed_fails": 3, - "cooldown_time": 60, - "retry_after": 5, - } - - final_settings = {**default_settings, **instance._router_settings} - - try: - instance._router = Router( - model_list=model_list, - routing_strategy=final_settings.get( - "routing_strategy", "usage-based-routing" - ), - num_retries=final_settings.get("num_retries", 3), - allowed_fails=final_settings.get("allowed_fails", 3), - cooldown_time=final_settings.get("cooldown_time", 60), - set_verbose=False, - ) - instance._initialized = True - logger.info( - "Vision LLM Router initialized with %d deployments, strategy: %s", - len(model_list), - final_settings.get("routing_strategy"), - ) - except Exception as e: - logger.error(f"Failed to initialize Vision LLM Router: {e}") - instance._router = None - - @classmethod - def _config_to_deployment(cls, config: dict) -> dict | None: - try: - if not config.get("model_name") or not config.get("api_key"): - return None - - model_string, resolved_kwargs = to_litellm( - native_connection_from_config(config), - config["model_name"], - ) - litellm_params: dict[str, Any] = {"model": model_string, **resolved_kwargs} - - deployment: dict[str, Any] = { - "model_name": "auto", - "litellm_params": litellm_params, - } - - if config.get("rpm"): - deployment["rpm"] = config["rpm"] - if config.get("tpm"): - deployment["tpm"] = config["tpm"] - - return deployment - - except Exception as e: - logger.warning(f"Failed to convert vision config to deployment: {e}") - return None - - @classmethod - def get_router(cls) -> Router | None: - instance = cls.get_instance() - return instance._router - - @classmethod - def is_initialized(cls) -> bool: - instance = cls.get_instance() - return instance._initialized and instance._router is not None - - @classmethod - def get_model_count(cls) -> int: - instance = cls.get_instance() - return len(instance._model_list) - - -def is_vision_auto_mode(config_id: int | None) -> bool: - return config_id == VISION_AUTO_MODE_ID - - -def build_vision_model_string( - litellm_provider: str, model_name: str, custom_provider: str | None -) -> str: - if custom_provider: - return f"{custom_provider}/{model_name}" - return f"{litellm_provider}/{model_name}" - - -def get_global_vision_llm_config(config_id: int) -> dict | None: - from app.config import config - - if config_id == VISION_AUTO_MODE_ID: - return { - "id": VISION_AUTO_MODE_ID, - "name": "Auto (Fastest)", - "provider": "AUTO", - "model_name": "auto", - "is_auto_mode": True, - } - if config_id > 0: - return None - for cfg in config.GLOBAL_VISION_LLM_CONFIGS: - if cfg.get("id") == config_id: - return cfg - return None diff --git a/surfsense_backend/app/services/vision_model_list_service.py b/surfsense_backend/app/services/vision_model_list_service.py deleted file mode 100644 index 6eae8c455..000000000 --- a/surfsense_backend/app/services/vision_model_list_service.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -Service for fetching and caching the vision-capable model list. - -Reuses the same OpenRouter public API and local fallback as the LLM model -list service, but filters for models that accept image input. -""" - -import json -import logging -import time -from pathlib import Path - -import httpx - -logger = logging.getLogger(__name__) - -OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models" -FALLBACK_FILE = ( - Path(__file__).parent.parent / "config" / "vision_model_list_fallback.json" -) -CACHE_TTL_SECONDS = 86400 # 24 hours - -_cache: list[dict] | None = None -_cache_timestamp: float = 0 - -OPENROUTER_SLUG_TO_VISION_PROVIDER: dict[str, str] = { - "openai": "OPENAI", - "anthropic": "ANTHROPIC", - "google": "GOOGLE", - "mistralai": "MISTRAL", - "x-ai": "XAI", -} - - -def _format_context_length(length: int | None) -> str | None: - if not length: - return None - if length >= 1_000_000: - return f"{length / 1_000_000:g}M" - if length >= 1_000: - return f"{length / 1_000:g}K" - return str(length) - - -async def _fetch_from_openrouter() -> list[dict] | None: - try: - async with httpx.AsyncClient(timeout=15) as client: - response = await client.get(OPENROUTER_API_URL) - response.raise_for_status() - data = response.json() - return data.get("data", []) - except Exception as e: - logger.warning("Failed to fetch from OpenRouter API for vision models: %s", e) - return None - - -def _load_fallback() -> list[dict]: - try: - with open(FALLBACK_FILE, encoding="utf-8") as f: - return json.load(f) - except Exception as e: - logger.error("Failed to load vision model fallback list: %s", e) - return [] - - -def _is_vision_model(model: dict) -> bool: - """Return True if the model accepts image input and outputs text.""" - arch = model.get("architecture", {}) - input_mods = arch.get("input_modalities", []) - output_mods = arch.get("output_modalities", []) - return "image" in input_mods and "text" in output_mods - - -def _process_vision_models(raw_models: list[dict]) -> list[dict]: - processed: list[dict] = [] - - for model in raw_models: - model_id: str = model.get("id", "") - name: str = model.get("name", "") - context_length = model.get("context_length") - - if "/" not in model_id: - continue - - if not _is_vision_model(model): - continue - - provider_slug, model_name = model_id.split("/", 1) - context_window = _format_context_length(context_length) - - processed.append( - { - "value": model_id, - "label": name, - "provider": "OPENROUTER", - "context_window": context_window, - } - ) - - direct_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug) - if direct_provider: - if direct_provider == "GOOGLE" and not model_name.startswith("gemini-"): - continue - - processed.append( - { - "value": model_name, - "label": name, - "provider": direct_provider, - "context_window": context_window, - } - ) - - return processed - - -async def get_vision_model_list() -> list[dict]: - global _cache, _cache_timestamp - - if _cache is not None and (time.time() - _cache_timestamp) < CACHE_TTL_SECONDS: - return _cache - - raw_models = await _fetch_from_openrouter() - - if raw_models is None: - logger.info("Using fallback vision model list") - return _load_fallback() - - processed = _process_vision_models(raw_models) - - _cache = processed - _cache_timestamp = time.time() - - return processed diff --git a/surfsense_backend/scripts/verify_chat_image_capability.py b/surfsense_backend/scripts/verify_chat_image_capability.py index 6e711f99a..e6a535711 100644 --- a/surfsense_backend/scripts/verify_chat_image_capability.py +++ b/surfsense_backend/scripts/verify_chat_image_capability.py @@ -330,31 +330,6 @@ async def probe_chat_configs(report: Report, *, live: bool) -> None: report.add(result) -async def probe_vision_configs(report: Report, *, live: bool) -> None: - print("\n[vision configs from global_vision_llm_configs (YAML-static)]") - for cfg in config.GLOBAL_VISION_LLM_CONFIGS: - if _is_or_dynamic(cfg): - continue - result = ProbeResult( - label=str(cfg.get("name") or cfg.get("model_name")), - surface="vision", - config_id=cfg.get("id"), - ) - # For vision configs, capability is implied — they're in the - # dedicated vision pool. Run the same resolver to flag any - # surprise disagreement. - cap_ok, cap_note = _probe_chat_capability(cfg) - result.capability_ok = cap_ok - result.capability_note = cap_note - if live: - t0 = time.perf_counter() - ok, note = await _live_chat_image_call(cfg) - result.live_ok = ok - result.live_note = note - result.duration_s = time.perf_counter() - t0 - report.add(result) - - async def probe_image_gen_configs(report: Report, *, live: bool) -> None: print( "\n[image generation configs from global_image_generation_configs (YAML-static)]" @@ -380,7 +355,7 @@ async def probe_image_gen_configs(report: Report, *, live: bool) -> None: async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: - """Sample one chat (vision-capable), one vision, one image-gen model + """Sample chat/vision-capable and image-gen models from the live OpenRouter catalogue. Doesn't iterate the full pool (would be hundreds of probes); just validates the integration end- to-end on a representative model from each surface.""" @@ -405,9 +380,6 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: for c in config.GLOBAL_LLM_CONFIGS if c.get("provider") == "OPENROUTER" and c.get("supports_image_input") ] - or_vision = [ - c for c in config.GLOBAL_VISION_LLM_CONFIGS if c.get("provider") == "OPENROUTER" - ] or_image_gen = [ c for c in config.GLOBAL_IMAGE_GEN_CONFIGS if c.get("provider") == "OPENROUTER" ] @@ -427,11 +399,6 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: ("or-chat", _pick_first(or_chat, "anthropic/claude")), ("or-chat", _pick_first(or_chat, "google/gemini-2.5-flash")), ] - vision_picks = [ - ("or-vision", _pick_first(or_vision, "openai/gpt-4o")), - ("or-vision", _pick_first(or_vision, "anthropic/claude")), - ("or-vision", _pick_first(or_vision, "google/gemini-2.5-flash")), - ] image_picks = [ ("or-image", _pick_first(or_image_gen, "google/gemini-2.5-flash-image")), # OpenRouter publishes OpenAI image models as ``openai/gpt-5-image*`` @@ -441,11 +408,11 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: ] print( - f" catalog: chat={len(or_chat)} vision={len(or_vision)} image_gen={len(or_image_gen)} " + f" catalog: chat_vision={len(or_chat)} image_gen={len(or_image_gen)} " f"(service initialized={service.is_initialized() if hasattr(service, 'is_initialized') else 'n/a'})" ) - for surface, picked in chat_picks + vision_picks + image_picks: + for surface, picked in chat_picks + image_picks: if not picked: report.add( ProbeResult( @@ -486,7 +453,6 @@ async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: async def main(args: argparse.Namespace) -> int: print("Loaded global configs:") print(f" chat: {len(config.GLOBAL_LLM_CONFIGS)} entries") - print(f" vision: {len(config.GLOBAL_VISION_LLM_CONFIGS)} entries") print(f" image-gen: {len(config.GLOBAL_IMAGE_GEN_CONFIGS)} entries") print(f" OR settings present: {bool(config.OPENROUTER_INTEGRATION_SETTINGS)}") @@ -507,8 +473,6 @@ async def main(args: argparse.Namespace) -> int: report = Report() if not args.skip_chat: await probe_chat_configs(report, live=args.live) - if not args.skip_vision: - await probe_vision_configs(report, live=args.live) if not args.skip_image_gen: await probe_image_gen_configs(report, live=args.live) if not args.skip_openrouter: @@ -528,7 +492,6 @@ def _parse_args() -> argparse.Namespace: ) parser.set_defaults(live=True) parser.add_argument("--skip-chat", action="store_true") - parser.add_argument("--skip-vision", action="store_true") parser.add_argument("--skip-image-gen", action="store_true") parser.add_argument("--skip-openrouter", action="store_true") return parser.parse_args() diff --git a/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py b/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py index 79da12933..f5709e517 100644 --- a/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py +++ b/surfsense_backend/tests/unit/automations/actions/builtin/agent_task/test_dependencies.py @@ -1,6 +1,6 @@ """Lock the runtime model-policy backstop in ``build_dependencies``. -Automations resolve their LLM from the *captured* ``agent_llm_id`` snapshot (so +Automations resolve their LLM from the *captured* ``chat_model_id`` snapshot (so runs are insulated from later chat/search-space model changes), and the model policy is re-checked at run time so a captured model that is no longer billable fails the run clearly. When no snapshot is present, resolution falls back to the @@ -45,10 +45,10 @@ def patched_side_effects(monkeypatch: pytest.MonkeyPatch): return None -async def test_build_dependencies_resolves_captured_agent_llm_id( +async def test_build_dependencies_resolves_captured_chat_model_id( monkeypatch: pytest.MonkeyPatch, patched_side_effects ) -> None: - """The bundle loads with the *captured* ``agent_llm_id``, not the live search space.""" + """The bundle loads with the *captured* ``chat_model_id``, not the live search space.""" captured: dict[str, Any] = {} async def _fake_load(_session, *, config_id, search_space_id): @@ -67,13 +67,13 @@ async def test_build_dependencies_resolves_captured_agent_llm_id( lambda _ss: pytest.fail("search-space policy should not run on captured path"), ) - search_space = SimpleNamespace(agent_llm_id=-99) + search_space = SimpleNamespace(chat_model_id=-99) result = await build_dependencies( session=_FakeSession(search_space), search_space_id=42, - agent_llm_id=-7, - image_generation_config_id=5, - vision_llm_config_id=-1, + chat_model_id=-7, + image_gen_model_id=5, + vision_model_id=-1, ) assert captured == {"config_id": -7, "search_space_id": 42} @@ -98,17 +98,17 @@ async def test_build_dependencies_validates_captured_ids( monkeypatch.setattr(deps_mod, "load_llm_bundle", _fake_load) await build_dependencies( - session=_FakeSession(SimpleNamespace(agent_llm_id=0)), + session=_FakeSession(SimpleNamespace(chat_model_id=0)), search_space_id=42, - agent_llm_id=-7, - image_generation_config_id=5, - vision_llm_config_id=-1, + chat_model_id=-7, + image_gen_model_id=5, + vision_model_id=-1, ) assert seen == { - "agent_llm_id": -7, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -7, + "image_gen_model_id": 5, + "vision_model_id": -1, } @@ -119,7 +119,7 @@ async def test_build_dependencies_raises_on_captured_policy_violation( def _raise(**_kw): raise AutomationModelPolicyError( - [{"kind": "image", "config_id": -2, "reason": "free model"}] + [{"kind": "image", "model_id": -2, "reason": "free model"}] ) monkeypatch.setattr(deps_mod, "assert_models_billable", _raise) @@ -131,11 +131,11 @@ async def test_build_dependencies_raises_on_captured_policy_violation( with pytest.raises(DependencyError): await build_dependencies( - session=_FakeSession(SimpleNamespace(agent_llm_id=-7)), + session=_FakeSession(SimpleNamespace(chat_model_id=-7)), search_space_id=42, - agent_llm_id=-7, - image_generation_config_id=-2, - vision_llm_config_id=-1, + chat_model_id=-7, + image_gen_model_id=-2, + vision_model_id=-1, ) @@ -157,7 +157,7 @@ async def test_build_dependencies_falls_back_to_search_space( lambda **_kw: pytest.fail("captured policy should not run on fallback path"), ) - search_space = SimpleNamespace(agent_llm_id=-7) + search_space = SimpleNamespace(chat_model_id=-7) result = await build_dependencies( session=_FakeSession(search_space), search_space_id=42 ) diff --git a/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py b/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py index d7e3c4a0c..c89624fbf 100644 --- a/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py +++ b/surfsense_backend/tests/unit/automations/runtime/test_executor_action_ctx.py @@ -28,9 +28,9 @@ def _run() -> SimpleNamespace: def test_build_action_ctx_propagates_captured_models() -> None: """``definition.models`` flows onto the ActionContext model fields.""" models = AutomationModels( - agent_llm_id=-1, - image_generation_config_id=5, - vision_llm_config_id=-1, + chat_model_id=-1, + image_gen_model_id=5, + vision_model_id=-1, ) ctx = _build_action_ctx( cast(AsyncSession, None), @@ -40,9 +40,9 @@ def test_build_action_ctx_propagates_captured_models() -> None: ) assert ctx.search_space_id == 42 - assert ctx.agent_llm_id == -1 - assert ctx.image_generation_config_id == 5 - assert ctx.vision_llm_config_id == -1 + assert ctx.chat_model_id == -1 + assert ctx.image_gen_model_id == 5 + assert ctx.vision_model_id == -1 def test_build_action_ctx_none_models_leaves_fields_none() -> None: @@ -54,6 +54,6 @@ def test_build_action_ctx_none_models_leaves_fields_none() -> None: None, ) - assert ctx.agent_llm_id is None - assert ctx.image_generation_config_id is None - assert ctx.vision_llm_config_id is None + assert ctx.chat_model_id is None + assert ctx.image_gen_model_id is None + assert ctx.vision_model_id is None diff --git a/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py b/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py index 25e193ffa..dc7221b11 100644 --- a/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py +++ b/surfsense_backend/tests/unit/automations/schemas/definition/test_envelope.py @@ -40,24 +40,24 @@ def test_automation_definition_models_round_trip() -> None: name="Daily digest", plan=[PlanStep(step_id="s1", action="agent_task")], models=AutomationModels( - agent_llm_id=-1, - image_generation_config_id=5, - vision_llm_config_id=-1, + chat_model_id=-1, + image_gen_model_id=5, + vision_model_id=-1, ), ) dumped = definition.model_dump(mode="json", by_alias=True) assert dumped["models"] == { - "agent_llm_id": -1, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -1, + "image_gen_model_id": 5, + "vision_model_id": -1, } restored = AutomationDefinition.model_validate(dumped) assert restored.models is not None - assert restored.models.agent_llm_id == -1 - assert restored.models.image_generation_config_id == 5 - assert restored.models.vision_llm_config_id == -1 + assert restored.models.chat_model_id == -1 + assert restored.models.image_gen_model_id == 5 + assert restored.models.vision_model_id == -1 def test_automation_definition_rejects_unknown_top_level_field() -> None: diff --git a/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py b/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py index 0bbff39dc..c97dec6a2 100644 --- a/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py +++ b/surfsense_backend/tests/unit/automations/services/test_automation_service_policy.py @@ -64,12 +64,12 @@ async def test_assert_models_billable_raises_422_on_violation( def _raise(_ss): raise AutomationModelPolicyError( - [{"kind": "llm", "config_id": 0, "reason": "Auto mode"}] + [{"kind": "llm", "model_id": 0, "reason": "Auto mode"}] ) monkeypatch.setattr(automation_mod, "assert_automation_models_billable", _raise) - service = _service(SimpleNamespace(agent_llm_id=0)) + service = _service(SimpleNamespace(chat_model_id=0)) with pytest.raises(HTTPException) as exc_info: await service._assert_models_billable(1) @@ -99,7 +99,7 @@ async def test_assert_models_billable_returns_search_space_when_ok( automation_mod, "assert_automation_models_billable", lambda _ss: None ) - search_space = SimpleNamespace(agent_llm_id=-1) + search_space = SimpleNamespace(chat_model_id=-1) service = _service(search_space) assert await service._assert_models_billable(1) is search_space @@ -123,9 +123,9 @@ async def test_create_injects_captured_models_from_search_space( monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) search_space = SimpleNamespace( - agent_llm_id=-1, - image_generation_config_id=5, - vision_llm_config_id=-1, + chat_model_id=-1, + image_gen_model_id=5, + vision_model_id=-1, ) service = _service(search_space) payload = AutomationCreate( @@ -137,9 +137,9 @@ async def test_create_injects_captured_models_from_search_space( automation = await service.create(payload) assert automation.definition["models"] == { - "agent_llm_id": -1, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -1, + "image_gen_model_id": 5, + "vision_model_id": -1, } @@ -162,9 +162,9 @@ async def test_create_treats_unset_prefs_as_auto_zero( monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) search_space = SimpleNamespace( - agent_llm_id=None, - image_generation_config_id=None, - vision_llm_config_id=None, + chat_model_id=None, + image_gen_model_id=None, + vision_model_id=None, ) service = _service(search_space) payload = AutomationCreate(search_space_id=1, name="A", definition=_definition()) @@ -172,9 +172,9 @@ async def test_create_treats_unset_prefs_as_auto_zero( automation = await service.create(payload) assert automation.definition["models"] == { - "agent_llm_id": 0, - "image_generation_config_id": 0, - "vision_llm_config_id": 0, + "chat_model_id": 0, + "image_gen_model_id": 0, + "vision_model_id": 0, } @@ -195,11 +195,11 @@ async def test_create_honors_selected_models_when_provided( ) validated: dict[str, Any] = {} - def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): + def _assert_ok(*, chat_model_id, image_gen_model_id, vision_model_id): validated["ids"] = ( - agent_llm_id, - image_generation_config_id, - vision_llm_config_id, + chat_model_id, + image_gen_model_id, + vision_model_id, ) monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok) @@ -213,15 +213,15 @@ async def test_create_honors_selected_models_when_provided( monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) monkeypatch.setattr(AutomationService, "_get_with_triggers_or_raise", _return_added) - service = _service(SimpleNamespace(agent_llm_id=-99)) + service = _service(SimpleNamespace(chat_model_id=-99)) payload = AutomationCreate( search_space_id=1, name="A", definition=_definition( models=AutomationModels( - agent_llm_id=-1, - image_generation_config_id=7, - vision_llm_config_id=-2, + chat_model_id=-1, + image_gen_model_id=7, + vision_model_id=-2, ) ), ) @@ -230,9 +230,9 @@ async def test_create_honors_selected_models_when_provided( assert validated["ids"] == (-1, 7, -2) assert automation.definition["models"] == { - "agent_llm_id": -1, - "image_generation_config_id": 7, - "vision_llm_config_id": -2, + "chat_model_id": -1, + "image_gen_model_id": 7, + "vision_model_id": -2, } @@ -241,9 +241,9 @@ async def test_create_rejects_unbillable_selected_models( ) -> None: """A non-billable explicit selection maps the policy error to HTTP 422.""" - def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): + def _raise(*, chat_model_id, image_gen_model_id, vision_model_id): raise AutomationModelPolicyError( - [{"kind": "llm", "config_id": -3, "reason": "free model"}] + [{"kind": "llm", "model_id": -3, "reason": "free model"}] ) monkeypatch.setattr(automation_mod, "assert_models_billable", _raise) @@ -253,15 +253,15 @@ async def test_create_rejects_unbillable_selected_models( monkeypatch.setattr(AutomationService, "_authorize", _noop_authorize) - service = _service(SimpleNamespace(agent_llm_id=-3)) + service = _service(SimpleNamespace(chat_model_id=-3)) payload = AutomationCreate( search_space_id=1, name="A", definition=_definition( models=AutomationModels( - agent_llm_id=-3, - image_generation_config_id=7, - vision_llm_config_id=-2, + chat_model_id=-3, + image_gen_model_id=7, + vision_model_id=-2, ) ), ) @@ -277,9 +277,9 @@ async def test_update_preserves_captured_models( ) -> None: """A definition edit carries over the previously captured ``models``.""" captured = { - "agent_llm_id": -1, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -1, + "image_gen_model_id": 5, + "vision_model_id": -1, } existing = SimpleNamespace( search_space_id=1, @@ -318,20 +318,20 @@ async def test_update_honors_changed_models_when_valid( "name": "A", "plan": [], "models": { - "agent_llm_id": -1, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -1, + "image_gen_model_id": 5, + "vision_model_id": -1, }, }, version=3, ) validated: dict[str, Any] = {} - def _assert_ok(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): + def _assert_ok(*, chat_model_id, image_gen_model_id, vision_model_id): validated["ids"] = ( - agent_llm_id, - image_generation_config_id, - vision_llm_config_id, + chat_model_id, + image_gen_model_id, + vision_model_id, ) monkeypatch.setattr(automation_mod, "assert_models_billable", _assert_ok) @@ -351,9 +351,9 @@ async def test_update_honors_changed_models_when_valid( patch = AutomationUpdate( definition=_definition( models=AutomationModels( - agent_llm_id=-2, - image_generation_config_id=9, - vision_llm_config_id=-2, + chat_model_id=-2, + image_gen_model_id=9, + vision_model_id=-2, ) ) ) @@ -362,9 +362,9 @@ async def test_update_honors_changed_models_when_valid( assert validated["ids"] == (-2, 9, -2) assert result.definition["models"] == { - "agent_llm_id": -2, - "image_generation_config_id": 9, - "vision_llm_config_id": -2, + "chat_model_id": -2, + "image_gen_model_id": 9, + "vision_model_id": -2, } assert result.version == 4 @@ -379,17 +379,17 @@ async def test_update_rejects_changed_unbillable_models( "name": "A", "plan": [], "models": { - "agent_llm_id": -1, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -1, + "image_gen_model_id": 5, + "vision_model_id": -1, }, }, version=3, ) - def _raise(*, agent_llm_id, image_generation_config_id, vision_llm_config_id): + def _raise(*, chat_model_id, image_gen_model_id, vision_model_id): raise AutomationModelPolicyError( - [{"kind": "llm", "config_id": -7, "reason": "free model"}] + [{"kind": "llm", "model_id": -7, "reason": "free model"}] ) monkeypatch.setattr(automation_mod, "assert_models_billable", _raise) @@ -409,9 +409,9 @@ async def test_update_rejects_changed_unbillable_models( patch = AutomationUpdate( definition=_definition( models=AutomationModels( - agent_llm_id=-7, - image_generation_config_id=5, - vision_llm_config_id=-1, + chat_model_id=-7, + image_gen_model_id=5, + vision_model_id=-1, ) ) ) @@ -431,9 +431,9 @@ async def test_update_keeps_unchanged_models_without_revalidation( premium without an unrelated edit tripping the policy check. """ captured = { - "agent_llm_id": -1, - "image_generation_config_id": 5, - "vision_llm_config_id": -1, + "chat_model_id": -1, + "image_gen_model_id": 5, + "vision_model_id": -1, } existing = SimpleNamespace( search_space_id=1, @@ -485,7 +485,7 @@ async def test_model_eligibility_authorizes_and_returns_payload( lambda _ss: {"allowed": False, "violations": [{"kind": "image"}]}, ) - service = _service(SimpleNamespace(agent_llm_id=-2)) + service = _service(SimpleNamespace(chat_model_id=-2)) result = await service.model_eligibility(search_space_id=5) assert result == {"allowed": False, "violations": [{"kind": "image"}]} diff --git a/surfsense_backend/tests/unit/automations/services/test_model_policy.py b/surfsense_backend/tests/unit/automations/services/test_model_policy.py index 8e0806151..574f6d9fd 100644 --- a/surfsense_backend/tests/unit/automations/services/test_model_policy.py +++ b/surfsense_backend/tests/unit/automations/services/test_model_policy.py @@ -27,9 +27,9 @@ pytestmark = pytest.mark.unit def _search_space(*, llm: int | None, image: int | None, vision: int | None): """Minimal stand-in for the ``SearchSpace`` ORM row the policy reads.""" return SimpleNamespace( - agent_llm_id=llm, - image_generation_config_id=image, - vision_llm_config_id=vision, + chat_model_id=llm, + image_gen_model_id=image, + vision_model_id=vision, ) @@ -39,29 +39,11 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch): Negative ids: -1 is premium, -2 is free, for each of llm/image/vision. """ - llm_configs = { - -1: {"id": -1, "billing_tier": "premium"}, - -2: {"id": -2, "billing_tier": "free"}, - } - monkeypatch.setattr( - "app.agents.chat.runtime.llm_config.load_global_llm_config_by_id", - lambda cid: llm_configs.get(cid), - ) - from app.config import config as app_config monkeypatch.setattr( app_config, - "GLOBAL_IMAGE_GEN_CONFIGS", - [ - {"id": -1, "billing_tier": "premium"}, - {"id": -2, "billing_tier": "free"}, - ], - raising=False, - ) - monkeypatch.setattr( - app_config, - "GLOBAL_VISION_LLM_CONFIGS", + "GLOBAL_MODELS", [ {"id": -1, "billing_tier": "premium"}, {"id": -2, "billing_tier": "free"}, @@ -71,7 +53,7 @@ def patched_globals(monkeypatch: pytest.MonkeyPatch): return None -@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None: """A positive config id is a user-owned BYOK model — always billable.""" allowed, reason = model_policy._classify(kind, 7) @@ -79,7 +61,7 @@ def test_byok_positive_id_is_allowed(kind: str, patched_globals) -> None: assert reason == "" -@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) @pytest.mark.parametrize("config_id", [0, None]) def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None: """Auto mode (id 0) and an unset slot (None) are blocked.""" @@ -88,7 +70,7 @@ def test_auto_mode_is_blocked(kind: str, config_id, patched_globals) -> None: assert "Auto mode" in reason -@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) def test_premium_global_is_allowed(kind: str, patched_globals) -> None: """A negative (global) id with premium billing tier is allowed.""" allowed, reason = model_policy._classify(kind, -1) @@ -96,7 +78,7 @@ def test_premium_global_is_allowed(kind: str, patched_globals) -> None: assert reason == "" -@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) def test_free_global_is_blocked(kind: str, patched_globals) -> None: """A negative (global) id with a free billing tier is blocked.""" allowed, reason = model_policy._classify(kind, -2) @@ -104,7 +86,7 @@ def test_free_global_is_blocked(kind: str, patched_globals) -> None: assert "free model" in reason -@pytest.mark.parametrize("kind", ["llm", "image", "vision"]) +@pytest.mark.parametrize("kind", ["chat", "image", "vision"]) def test_unknown_global_id_is_blocked(kind: str, patched_globals) -> None: """A negative id that resolves to no config is treated as not premium.""" allowed, _ = model_policy._classify(kind, -999) @@ -125,10 +107,10 @@ def test_eligibility_reports_each_violation(patched_globals) -> None: assert result["allowed"] is False kinds = {v["kind"] for v in result["violations"]} - assert kinds == {"llm", "image", "vision"} - # config_id is echoed back for the UI / settings deep-link. - by_kind = {v["kind"]: v["config_id"] for v in result["violations"]} - assert by_kind == {"llm": -2, "image": 0, "vision": -2} + assert kinds == {"chat", "image", "vision"} + # model_id is echoed back for the UI / settings deep-link. + by_kind = {v["kind"]: v["model_id"] for v in result["violations"]} + assert by_kind == {"chat": -2, "image": 0, "vision": -2} def test_assert_raises_with_violations(patched_globals) -> None: @@ -138,7 +120,7 @@ def test_assert_raises_with_violations(patched_globals) -> None: assert_automation_models_billable(search_space) assert len(exc_info.value.violations) == 1 - assert exc_info.value.violations[0]["kind"] == "llm" + assert exc_info.value.violations[0]["kind"] == "chat" def test_assert_passes_when_all_billable(patched_globals) -> None: @@ -153,7 +135,7 @@ def test_assert_passes_when_all_billable(patched_globals) -> None: def test_get_model_eligibility_all_billable(patched_globals) -> None: """Premium LLM + BYOK image + premium vision (explicit ids) → allowed.""" result = get_model_eligibility( - agent_llm_id=-1, image_generation_config_id=5, vision_llm_config_id=-1 + chat_model_id=-1, image_gen_model_id=5, vision_model_id=-1 ) assert result == {"allowed": True, "violations": []} @@ -161,28 +143,28 @@ def test_get_model_eligibility_all_billable(patched_globals) -> None: def test_get_model_eligibility_reports_each_violation(patched_globals) -> None: """Free LLM, Auto image, free vision (explicit ids) each produce a violation.""" result = get_model_eligibility( - agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2 + chat_model_id=-2, image_gen_model_id=0, vision_model_id=-2 ) assert result["allowed"] is False - by_kind = {v["kind"]: v["config_id"] for v in result["violations"]} - assert by_kind == {"llm": -2, "image": 0, "vision": -2} + by_kind = {v["kind"]: v["model_id"] for v in result["violations"]} + assert by_kind == {"chat": -2, "image": 0, "vision": -2} def test_assert_models_billable_raises(patched_globals) -> None: """``assert_models_billable`` raises when any explicit id is blocked.""" with pytest.raises(AutomationModelPolicyError) as exc_info: assert_models_billable( - agent_llm_id=0, image_generation_config_id=5, vision_llm_config_id=-1 + chat_model_id=0, image_gen_model_id=5, vision_model_id=-1 ) assert len(exc_info.value.violations) == 1 - assert exc_info.value.violations[0]["kind"] == "llm" + assert exc_info.value.violations[0]["kind"] == "chat" def test_assert_models_billable_passes(patched_globals) -> None: """No exception when every explicit id is premium or BYOK.""" assert ( assert_models_billable( - agent_llm_id=3, image_generation_config_id=-1, vision_llm_config_id=4 + chat_model_id=3, image_gen_model_id=-1, vision_model_id=4 ) is None ) @@ -192,5 +174,5 @@ def test_search_space_wrapper_delegates_to_core(patched_globals) -> None: """The search-space wrapper produces the same result as the ID core.""" search_space = _search_space(llm=-2, image=0, vision=-2) assert get_automation_model_eligibility(search_space) == get_model_eligibility( - agent_llm_id=-2, image_generation_config_id=0, vision_llm_config_id=-2 + chat_model_id=-2, image_gen_model_id=0, vision_model_id=-2 ) diff --git a/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py b/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py deleted file mode 100644 index c9f18d77d..000000000 --- a/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Unit tests for ``supports_image_input`` derivation on BYOK chat config -endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``). - -There is no DB column for ``supports_image_input`` on -``NewLLMConfig`` — the value is resolved at the API boundary by -``derive_supports_image_input`` so the new-chat selector / streaming -task can read the same field shape regardless of source (BYOK vs YAML -vs OpenRouter dynamic). Default-allow on unknown so we don't lock the -user out of their own model choice. -""" - -from __future__ import annotations - -from datetime import UTC, datetime -from types import SimpleNamespace -from uuid import uuid4 - -import pytest - -from app.db import LiteLLMProvider -from app.routes import new_llm_config_routes - -pytestmark = pytest.mark.unit - - -def _byok_row( - *, - id_: int, - model_name: str, - base_model: str | None = None, - provider: LiteLLMProvider = LiteLLMProvider.OPENAI, - custom_provider: str | None = None, -) -> object: - """Mimic the SQLAlchemy row's attribute surface; ``model_validate`` - walks ``from_attributes=True`` so a ``SimpleNamespace`` is enough. - - ``provider`` is a real ``LiteLLMProvider`` enum value so Pydantic's - enum validator accepts it — same as the ORM row would carry.""" - return SimpleNamespace( - id=id_, - name=f"BYOK-{id_}", - description=None, - provider=provider, - custom_provider=custom_provider, - model_name=model_name, - api_key="sk-byok", - api_base=None, - litellm_params={"base_model": base_model} if base_model else None, - system_instructions="", - use_default_system_instructions=True, - citations_enabled=True, - created_at=datetime.now(tz=UTC), - search_space_id=42, - user_id=uuid4(), - ) - - -def test_serialize_byok_known_vision_model_resolves_true(): - """The catalog resolver consults LiteLLM's map for ``gpt-4o`` -> - True. The serialized row carries that value through to the - ``NewLLMConfigRead`` schema.""" - row = _byok_row(id_=1, model_name="gpt-4o") - serialized = new_llm_config_routes._serialize_byok_config(row) - - assert serialized.supports_image_input is True - assert serialized.id == 1 - assert serialized.model_name == "gpt-4o" - - -def test_serialize_byok_unknown_model_default_allows(): - """Unknown / unmapped: default-allow. The streaming-task safety net - is the actual block, and it requires LiteLLM to *explicitly* say - text-only — so a brand new BYOK model should not be pre-judged.""" - row = _byok_row( - id_=2, - model_name="brand-new-model-x9-unmapped", - provider=LiteLLMProvider.CUSTOM, - custom_provider="brand_new_proxy", - ) - serialized = new_llm_config_routes._serialize_byok_config(row) - - assert serialized.supports_image_input is True - - -def test_serialize_byok_uses_base_model_when_present(): - """Azure-style: ``model_name`` is the deployment id, ``base_model`` - inside ``litellm_params`` is the canonical sku LiteLLM knows. The - helper must consult ``base_model`` first or unrecognised deployment - ids would shadow the real capability.""" - row = _byok_row( - id_=3, - model_name="my-azure-deployment-id-no-litellm-knows-this", - base_model="gpt-4o", - provider=LiteLLMProvider.AZURE_OPENAI, - ) - serialized = new_llm_config_routes._serialize_byok_config(row) - - assert serialized.supports_image_input is True - - -def test_serialize_byok_returns_pydantic_read_model(): - """The route now returns ``NewLLMConfigRead`` (not the raw ORM) so - the schema additions are guaranteed to be present in the API - surface. This guards against a future regression where someone - deletes the augmentation step and falls back to ORM passthrough.""" - from app.schemas import NewLLMConfigRead - - row = _byok_row(id_=4, model_name="gpt-4o") - serialized = new_llm_config_routes._serialize_byok_config(row) - assert isinstance(serialized, NewLLMConfigRead) diff --git a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py deleted file mode 100644 index fff61f14b..000000000 --- a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py +++ /dev/null @@ -1,184 +0,0 @@ -"""Unit tests for ``is_premium`` derivation on the global image-gen and -vision-LLM list endpoints. - -Chat globals (``GET /global-llm-configs``) already emit -``is_premium = (billing_tier == "premium")``. Image and vision did not, -which made the new-chat ``model-selector`` render the Free/Premium badge -on the Chat tab but skip it on the Image and Vision tabs (the selector -keys its badge logic off ``is_premium``). These tests pin parity: - -* YAML free entry → ``is_premium=False`` -* YAML premium entry → ``is_premium=True`` -* OpenRouter dynamic premium entry → ``is_premium=True`` -* Auto stub (always emitted when at least one config is present) - → ``is_premium=False`` -""" - -from __future__ import annotations - -import pytest - -pytestmark = pytest.mark.unit - - -_IMAGE_FIXTURE: list[dict] = [ - { - "id": -1, - "name": "DALL-E 3", - "litellm_provider": "openai", - "model_name": "dall-e-3", - "api_key": "sk-test", - "billing_tier": "free", - }, - { - "id": -2, - "name": "GPT-Image 1 (premium)", - "litellm_provider": "openai", - "model_name": "gpt-image-1", - "api_key": "sk-test", - "billing_tier": "premium", - }, - { - "id": -20_001, - "name": "google/gemini-2.5-flash-image (OpenRouter)", - "litellm_provider": "openrouter", - "model_name": "google/gemini-2.5-flash-image", - "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - "billing_tier": "premium", - }, -] - - -_VISION_FIXTURE: list[dict] = [ - { - "id": -1, - "name": "GPT-4o Vision", - "litellm_provider": "openai", - "model_name": "gpt-4o", - "api_key": "sk-test", - "billing_tier": "free", - }, - { - "id": -2, - "name": "Claude 3.5 Sonnet (premium)", - "litellm_provider": "anthropic", - "model_name": "claude-3-5-sonnet", - "api_key": "sk-ant-test", - "billing_tier": "premium", - }, - { - "id": -30_001, - "name": "openai/gpt-4o (OpenRouter)", - "litellm_provider": "openrouter", - "model_name": "openai/gpt-4o", - "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - "billing_tier": "premium", - }, -] - - -# ============================================================================= -# Image generation -# ============================================================================= - - -@pytest.mark.asyncio -async def test_global_image_gen_configs_emit_is_premium(monkeypatch): - """Each emitted config must carry ``is_premium`` derived server-side - from ``billing_tier``. The Auto stub is always free. - """ - from app.config import config - from app.routes import image_generation_routes - - monkeypatch.setattr( - config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False - ) - - payload = await image_generation_routes.get_global_image_gen_configs(user=None) - - by_id = {c["id"]: c for c in payload} - - # Auto stub is always emitted when at least one global config exists, - # and it must always declare itself free (Auto-mode billing-tier - # surfacing is a separate follow-up). - assert 0 in by_id, "Auto stub should be emitted when at least one config exists" - assert by_id[0]["is_premium"] is False - assert by_id[0]["billing_tier"] == "free" - - # YAML free entry — ``is_premium=False`` - assert by_id[-1]["is_premium"] is False - assert by_id[-1]["billing_tier"] == "free" - - # YAML premium entry — ``is_premium=True`` - assert by_id[-2]["is_premium"] is True - assert by_id[-2]["billing_tier"] == "premium" - - # OpenRouter dynamic premium entry — same field, same derivation - assert by_id[-20_001]["is_premium"] is True - assert by_id[-20_001]["billing_tier"] == "premium" - - # Every emitted dict (including Auto) must have the field — never missing. - for cfg in payload: - assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}" - assert isinstance(cfg["is_premium"], bool) - - -@pytest.mark.asyncio -async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch): - """When there are no global configs at all, the endpoint emits an - empty list (no Auto stub) — Auto mode would have nothing to route to. - """ - from app.config import config - from app.routes import image_generation_routes - - monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False) - payload = await image_generation_routes.get_global_image_gen_configs(user=None) - assert payload == [] - - -# ============================================================================= -# Vision LLM -# ============================================================================= - - -@pytest.mark.asyncio -async def test_global_vision_llm_configs_emit_is_premium(monkeypatch): - from app.config import config - from app.routes import vision_llm_routes - - monkeypatch.setattr( - config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False - ) - - payload = await vision_llm_routes.get_global_vision_llm_configs(user=None) - - by_id = {c["id"]: c for c in payload} - - assert 0 in by_id, "Auto stub should be emitted when at least one config exists" - assert by_id[0]["is_premium"] is False - assert by_id[0]["billing_tier"] == "free" - - assert by_id[-1]["is_premium"] is False - assert by_id[-1]["billing_tier"] == "free" - - assert by_id[-2]["is_premium"] is True - assert by_id[-2]["billing_tier"] == "premium" - - assert by_id[-30_001]["is_premium"] is True - assert by_id[-30_001]["billing_tier"] == "premium" - - for cfg in payload: - assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}" - assert isinstance(cfg["is_premium"], bool) - - -@pytest.mark.asyncio -async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch): - from app.config import config - from app.routes import vision_llm_routes - - monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False) - payload = await vision_llm_routes.get_global_vision_llm_configs(user=None) - assert payload == [] diff --git a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py deleted file mode 100644 index 67d1112f3..000000000 --- a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Unit tests for ``supports_image_input`` derivation on the chat global -config endpoint (``GET /global-new-llm-configs``). - -Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``): - -1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML - loader for operator overrides, or by the OpenRouter integration from - ``architecture.input_modalities``) — wins. -2. ``derive_supports_image_input`` helper — default-allow on unknown - models, only False when LiteLLM / OR modalities are definitive. - -The flag is purely informational at the API boundary. The streaming -task safety net (``is_known_text_only_chat_model``) is the actual block, -and it requires LiteLLM to *explicitly* mark the model as text-only. -""" - -from __future__ import annotations - -import pytest - -pytestmark = pytest.mark.unit - - -_FIXTURE: list[dict] = [ - { - "id": -1, - "name": "GPT-4o (explicit true)", - "description": "vision-capable, explicit YAML override", - "litellm_provider": "openai", - "model_name": "gpt-4o", - "api_key": "sk-test", - "billing_tier": "free", - "supports_image_input": True, - }, - { - "id": -2, - "name": "DeepSeek V3 (explicit false)", - "description": "OpenRouter dynamic — modality-derived false", - "litellm_provider": "openrouter", - "model_name": "deepseek/deepseek-v3.2-exp", - "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - "billing_tier": "free", - "supports_image_input": False, - }, - { - "id": -10_010, - "name": "Unannotated GPT-4o", - "description": "no flag set — resolver should derive True via LiteLLM", - "litellm_provider": "openai", - "model_name": "gpt-4o", - "api_key": "sk-test", - "billing_tier": "free", - # supports_image_input intentionally absent - }, - { - "id": -10_011, - "name": "Unannotated unknown model", - "description": "unmapped — default-allow True", - "litellm_provider": "custom", - "custom_provider": "brand_new_proxy", - "model_name": "brand-new-model-x9", - "api_key": "sk-test", - "billing_tier": "free", - }, -] - - -@pytest.mark.asyncio -async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch): - """Each emitted chat config carries ``supports_image_input`` as a - bool. Explicit values win; unannotated entries are resolved via the - helper (default-allow True).""" - from app.config import config - from app.routes import new_llm_config_routes - - monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False) - - payload = await new_llm_config_routes.get_global_new_llm_configs(user=None) - by_id = {c["id"]: c for c in payload} - - # Auto stub: optimistic True so the user can keep Auto selected with - # vision-capable deployments somewhere in the pool. - assert 0 in by_id, "Auto stub should be emitted when configs exist" - assert by_id[0]["supports_image_input"] is True - assert by_id[0]["is_auto_mode"] is True - - # Explicit True is preserved. - assert by_id[-1]["supports_image_input"] is True - - # Explicit False is preserved (the exact failure mode the safety net - # guards against — DeepSeek V3 over OpenRouter would 404 with "No - # endpoints found that support image input"). - assert by_id[-2]["supports_image_input"] is False - - # Unannotated GPT-4o: resolver consults LiteLLM, which says vision. - assert by_id[-10_010]["supports_image_input"] is True - - # Unknown / unmapped model: default-allow rather than pre-judge. - assert by_id[-10_011]["supports_image_input"] is True - - for cfg in payload: - assert "supports_image_input" in cfg, ( - f"supports_image_input missing from {cfg.get('id')}" - ) - assert isinstance(cfg["supports_image_input"], bool) diff --git a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py index 53c0f50a9..4dd918927 100644 --- a/surfsense_backend/tests/unit/routes/test_image_gen_quota.py +++ b/surfsense_backend/tests/unit/routes/test_image_gen_quota.py @@ -27,9 +27,18 @@ async def test_resolve_billing_for_auto_mode(monkeypatch): from app.routes import image_generation_routes from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS - search_space = SimpleNamespace(image_generation_config_id=None) + async def _no_auto_candidates(*_args, **_kwargs): + return [] + + monkeypatch.setattr( + image_generation_routes, + "auto_model_candidates", + _no_auto_candidates, + ) + + search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None) tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( - session=None, # Not consumed on this code path. + session=None, config_id=0, # IMAGE_GEN_AUTO_MODE_ID search_space=search_space, ) @@ -45,26 +54,42 @@ async def test_resolve_billing_for_premium_global_config(monkeypatch): monkeypatch.setattr( config, - "GLOBAL_IMAGE_GEN_CONFIGS", + "GLOBAL_MODELS", [ { "id": -1, - "litellm_provider": "openai", - "model_name": "gpt-image-1", + "connection_id": -101, + "model_id": "gpt-image-1", "billing_tier": "premium", - "quota_reserve_micros": 75_000, + "catalog": {"quota_reserve_micros": 75_000}, }, { "id": -2, - "litellm_provider": "openrouter", - "model_name": "google/gemini-2.5-flash-image", + "connection_id": -102, + "model_id": "google/gemini-2.5-flash-image", "billing_tier": "free", + "catalog": {}, + }, + ], + raising=False, + ) + monkeypatch.setattr( + config, + "GLOBAL_CONNECTIONS", + [ + {"id": -101, "provider": "openai", "api_key": "sk-test", "base_url": None, "extra": {}}, + { + "id": -102, + "provider": "openrouter", + "api_key": "sk-or-test", + "base_url": "https://openrouter.ai/api/v1", + "extra": {}, }, ], raising=False, ) - search_space = SimpleNamespace(image_generation_config_id=None) + search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None) # Premium with override. tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( @@ -94,7 +119,7 @@ async def test_resolve_billing_for_user_owned_byok_is_free(): from app.routes import image_generation_routes from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS - search_space = SimpleNamespace(image_generation_config_id=None) + search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=None) tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen( session=None, config_id=42, search_space=search_space ) @@ -105,7 +130,7 @@ async def test_resolve_billing_for_user_owned_byok_is_free(): @pytest.mark.asyncio async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch): - """When the request omits ``image_generation_config_id``, the helper + """When the request omits ``image_gen_model_id``, the helper must consult the search space's default — so a search space pinned to a premium global config still gates new requests by quota. """ @@ -114,19 +139,26 @@ async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch): monkeypatch.setattr( config, - "GLOBAL_IMAGE_GEN_CONFIGS", + "GLOBAL_MODELS", [ { "id": -7, - "litellm_provider": "openai", - "model_name": "gpt-image-1", + "connection_id": -101, + "model_id": "gpt-image-1", "billing_tier": "premium", + "catalog": {}, } ], raising=False, ) + monkeypatch.setattr( + config, + "GLOBAL_CONNECTIONS", + [{"id": -101, "provider": "openai", "api_key": "sk-test", "base_url": None, "extra": {}}], + raising=False, + ) - search_space = SimpleNamespace(image_generation_config_id=-7) + search_space = SimpleNamespace(id=1, user_id=None, image_gen_model_id=-7) ( tier, model, diff --git a/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py index fa8819b39..b43540ba7 100644 --- a/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py +++ b/surfsense_backend/tests/unit/services/test_agent_billing_resolver.py @@ -1,27 +1,4 @@ -"""Unit tests for ``_resolve_agent_billing_for_search_space``. - -Validates the resolver used by Celery podcast/video tasks to compute -``(owner_user_id, billing_tier, base_model)`` from a search space and its -agent LLM config. The resolver mirrors chat's billing-resolution pattern at -``stream_new_chat.py:2294-2351`` and is the single integration point that -prevents Auto-mode podcast/video from leaking premium credit. - -Coverage: - -* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium - global → returns ``("premium", )``. -* Auto mode + ``thread_id`` set, pin resolves to a negative-id free - global → returns ``("free", )``. -* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config - → always ``"free"``. -* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without - hitting the pin service. -* Negative id (no Auto) → uses ``get_global_llm_config``'s - ``billing_tier``. -* Positive id (user BYOK) → always ``"free"``. -* Search space not found → raises ``ValueError``. -* ``agent_llm_id`` is None → raises ``ValueError``. -""" +"""Unit tests for ``_resolve_agent_billing_for_search_space``.""" from __future__ import annotations @@ -34,11 +11,6 @@ import pytest pytestmark = pytest.mark.unit -# --------------------------------------------------------------------------- -# Fakes -# --------------------------------------------------------------------------- - - class _FakeExecResult: def __init__(self, obj): self._obj = obj @@ -51,14 +23,6 @@ class _FakeExecResult: class _FakeSession: - """Tiny AsyncSession stub. - - ``responses`` is a list of objects to return from successive - ``execute()`` calls (in order). The resolver makes at most two - ``execute()`` calls (search-space lookup, then optionally NewLLMConfig - lookup), so two queued responses cover the matrix. - """ - def __init__(self, responses: list): self._responses = list(responses) @@ -67,9 +31,6 @@ class _FakeSession: return _FakeExecResult(None) return _FakeExecResult(self._responses.pop(0)) - async def commit(self) -> None: - pass - @dataclass class _FakePinResolution: @@ -78,53 +39,33 @@ class _FakePinResolution: from_existing_pin: bool = False -def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace: - return SimpleNamespace( - id=42, - agent_llm_id=agent_llm_id, - user_id=user_id, - ) +def _make_search_space(*, chat_model_id: int | None, user_id: UUID) -> SimpleNamespace: + return SimpleNamespace(id=42, chat_model_id=chat_model_id, user_id=user_id) -def _make_byok_config( - *, id_: int, base_model: str | None = None, model_name: str = "gpt-byok" +def _make_byok_model( + *, id_: int, base_model: str | None = None, model_id: str = "gpt-byok" ) -> SimpleNamespace: return SimpleNamespace( id=id_, - model_name=model_name, - litellm_params={"base_model": base_model} if base_model else {}, + model_id=model_id, + catalog={"base_model": base_model} if base_model else {}, + connection=SimpleNamespace(enabled=True, search_space_id=42, user_id=None), ) -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - @pytest.mark.asyncio async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch): - """Auto + thread → pin service resolves to negative-id premium config → - resolver returns ``("premium", )``.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)]) - # Mock the pin service to return a concrete premium config id. - async def _fake_resolve_pin( - sess, - *, - thread_id, - search_space_id, - user_id, - selected_llm_config_id, - force_repin_free=False, - ): - assert selected_llm_config_id == 0 - assert thread_id == 99 + async def _fake_resolve_pin(*_args, **kwargs): + assert kwargs["selected_llm_config_id"] == 0 + assert kwargs["thread_id"] == 99 return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium") - # Mock global config lookup to return a premium entry. def _fake_get_global(cfg_id): if cfg_id == -1: return { @@ -135,8 +76,6 @@ async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch): } return None - # Lazy imports inside the resolver — patch the *target* modules so the - # imported names resolve to our fakes. import app.services.auto_model_pin_service as pin_module import app.services.llm_service as llm_module @@ -154,77 +93,18 @@ async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch): assert base_model == "gpt-5.4" -@pytest.mark.asyncio -async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch): - """Auto + thread → pin returns negative-id free config → resolver - returns ``("free", )``. Same path the pin service takes for - out-of-credit users (graceful degradation).""" - from app.services.billable_calls import _resolve_agent_billing_for_search_space - - user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) - - async def _fake_resolve_pin( - sess, - *, - thread_id, - search_space_id, - user_id, - selected_llm_config_id, - force_repin_free=False, - ): - return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free") - - def _fake_get_global(cfg_id): - if cfg_id == -3: - return { - "id": -3, - "model_name": "openrouter/free-model", - "billing_tier": "free", - "litellm_params": {"base_model": "openrouter/free-model"}, - } - return None - - import app.services.auto_model_pin_service as pin_module - import app.services.llm_service as llm_module - - monkeypatch.setattr( - pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin - ) - monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) - - owner, tier, base_model = await _resolve_agent_billing_for_search_space( - session, search_space_id=42, thread_id=99 - ) - - assert owner == user_id - assert tier == "free" - assert base_model == "openrouter/free-model" - - @pytest.mark.asyncio async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch): - """Auto + thread → pin returns positive-id BYOK config → resolver - returns ``("free", ...)`` (BYOK is always free per - ``AgentConfig.from_new_llm_config``).""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - search_space = _make_search_space(agent_llm_id=0, user_id=user_id) - byok_cfg = _make_byok_config( - id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude" + search_space = _make_search_space(chat_model_id=0, user_id=user_id) + byok_model = _make_byok_model( + id_=17, base_model="anthropic/claude-3-haiku", model_id="my-claude" ) - session = _FakeSession([search_space, byok_cfg]) + session = _FakeSession([search_space, byok_model]) - async def _fake_resolve_pin( - sess, - *, - thread_id, - search_space_id, - user_id, - selected_llm_config_id, - force_repin_free=False, - ): + async def _fake_resolve_pin(*_args, **_kwargs): return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free") import app.services.auto_model_pin_service as pin_module @@ -244,13 +124,10 @@ async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch): @pytest.mark.asyncio async def test_auto_mode_without_thread_id_falls_back_to_free(): - """Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking - the pin service. Forward-compat fallback for any future direct-API - entrypoint that doesn't have a chat thread.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)]) owner, tier, base_model = await _resolve_agent_billing_for_search_space( session, search_space_id=42, thread_id=None @@ -263,13 +140,10 @@ async def test_auto_mode_without_thread_id_falls_back_to_free(): @pytest.mark.asyncio async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch): - """If the pin service raises ``ValueError`` (thread missing / - mismatched search space), the resolver should log and return free - rather than killing the whole task.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=0, user_id=user_id)]) async def _fake_resolve_pin(*args, **kwargs): raise ValueError("thread missing") @@ -291,12 +165,10 @@ async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch): @pytest.mark.asyncio async def test_negative_id_premium_global_returns_premium(monkeypatch): - """Explicit negative agent_llm_id → ``get_global_llm_config`` → - return its ``billing_tier``.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=-1, user_id=user_id)]) def _fake_get_global(cfg_id): return { @@ -319,50 +191,15 @@ async def test_negative_id_premium_global_returns_premium(monkeypatch): assert base_model == "gpt-5.4" -@pytest.mark.asyncio -async def test_negative_id_free_global_returns_free(monkeypatch): - from app.services.billable_calls import _resolve_agent_billing_for_search_space - - user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)]) - - def _fake_get_global(cfg_id): - return { - "id": cfg_id, - "model_name": "openrouter/some-free", - "billing_tier": "free", - "litellm_params": {"base_model": "openrouter/some-free"}, - } - - import app.services.llm_service as llm_module - - monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global) - - owner, tier, base_model = await _resolve_agent_billing_for_search_space( - session, search_space_id=42, thread_id=None - ) - - assert owner == user_id - assert tier == "free" - assert base_model == "openrouter/some-free" - - @pytest.mark.asyncio async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch): - """When the global config has no ``litellm_params.base_model``, the - resolver falls back to ``model_name`` — matching chat's behavior.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=-5, user_id=user_id)]) def _fake_get_global(cfg_id): - return { - "id": cfg_id, - "model_name": "fallback-model", - "billing_tier": "premium", - # No litellm_params. - } + return {"id": cfg_id, "model_name": "fallback-model", "billing_tier": "premium"} import app.services.llm_service as llm_module @@ -378,14 +215,12 @@ async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypat @pytest.mark.asyncio async def test_positive_id_byok_is_always_free(): - """Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free, - regardless of underlying provider tier.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - search_space = _make_search_space(agent_llm_id=23, user_id=user_id) - byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet") - session = _FakeSession([search_space, byok_cfg]) + search_space = _make_search_space(chat_model_id=23, user_id=user_id) + byok_model = _make_byok_model(id_=23, base_model="anthropic/claude-3.5-sonnet") + session = _FakeSession([search_space, byok_model]) owner, tier, base_model = await _resolve_agent_billing_for_search_space( session, search_space_id=42 @@ -398,13 +233,10 @@ async def test_positive_id_byok_is_always_free(): @pytest.mark.asyncio async def test_positive_id_byok_missing_returns_free_with_empty_base_model(): - """If the BYOK config row is missing/deleted but the search space still - points at it, the resolver still returns free (no debit) with an empty - base_model — billable_call's premium path is skipped, no harm done.""" from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=99, user_id=user_id)]) owner, tier, base_model = await _resolve_agent_billing_for_search_space( session, search_space_id=42 @@ -419,18 +251,18 @@ async def test_positive_id_byok_missing_returns_free_with_empty_base_model(): async def test_search_space_not_found_raises_value_error(): from app.services.billable_calls import _resolve_agent_billing_for_search_space - session = _FakeSession([None]) - with pytest.raises(ValueError, match="Search space"): - await _resolve_agent_billing_for_search_space(session, search_space_id=999) + await _resolve_agent_billing_for_search_space( + _FakeSession([None]), search_space_id=999 + ) @pytest.mark.asyncio -async def test_agent_llm_id_none_raises_value_error(): +async def test_chat_model_id_none_raises_value_error(): from app.services.billable_calls import _resolve_agent_billing_for_search_space user_id = uuid4() - session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)]) + session = _FakeSession([_make_search_space(chat_model_id=None, user_id=user_id)]) - with pytest.raises(ValueError, match="agent_llm_id"): + with pytest.raises(ValueError, match="chat_model_id"): await _resolve_agent_billing_for_search_space(session, search_space_id=42) diff --git a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py index d7eb32732..d7c12a6e0 100644 --- a/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py +++ b/surfsense_backend/tests/unit/services/test_auto_model_pin_service.py @@ -32,8 +32,9 @@ class _FakeQuotaResult: class _FakeExecResult: - def __init__(self, thread): + def __init__(self, *, thread=None, scalars=None): self._thread = thread + self._scalars = scalars or [] def unique(self): return self @@ -41,19 +42,69 @@ class _FakeExecResult: def scalar_one_or_none(self): return self._thread + def scalars(self): + return SimpleNamespace(all=lambda: self._scalars) + class _FakeSession: - def __init__(self, thread): + def __init__(self, thread, *, models=None): self.thread = thread + self.models = models or [] self.commit_count = 0 + self.execute_count = 0 async def execute(self, _stmt): - return _FakeExecResult(self.thread) + self.execute_count += 1 + if self.execute_count == 1: + return _FakeExecResult(thread=self.thread) + return _FakeExecResult(scalars=self.models) async def commit(self): self.commit_count += 1 +def _set_global_llm_configs(monkeypatch, config, configs: list[dict]): + """Patch the new global model catalog shape from compact legacy cfg fixtures.""" + connections = [] + models = [] + for cfg in configs: + config_id = int(cfg["id"]) + connection_id = config_id - 100_000 + provider = cfg.get("provider") or cfg.get("litellm_provider") + model_name = cfg["model_name"] + connections.append( + { + "id": connection_id, + "provider": provider, + "scope": "GLOBAL", + "enabled": True, + } + ) + models.append( + { + "id": config_id, + "connection_id": connection_id, + "model_id": model_name, + "display_name": cfg.get("name") or model_name, + "supports_chat": cfg.get("supports_chat", True), + "supports_image_input": cfg.get("supports_image_input", True), + "supports_tools": cfg.get("supports_tools", True), + "supports_image_generation": cfg.get("supports_image_generation", False), + "capabilities_override": cfg.get("capabilities_override") or {}, + "billing_tier": cfg.get("billing_tier", "free"), + "catalog": { + "auto_pin_tier": cfg.get("auto_pin_tier"), + "quality_score": cfg.get("quality_score") + or cfg.get("quality_score_static"), + }, + } + ) + + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", configs) + monkeypatch.setattr(config, "GLOBAL_CONNECTIONS", connections) + monkeypatch.setattr(config, "GLOBAL_MODELS", models) + + def _thread( *, search_space_id: int = 10, @@ -71,9 +122,9 @@ async def test_auto_first_turn_pins_one_model(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"}, { @@ -111,9 +162,9 @@ async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -2, @@ -158,9 +209,9 @@ async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -216,9 +267,9 @@ async def test_next_turn_reuses_existing_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -257,9 +308,9 @@ async def test_premium_eligible_auto_can_pin_premium(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -295,9 +346,9 @@ async def test_premium_ineligible_auto_pins_free_only(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -2, @@ -340,9 +391,9 @@ async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -2, @@ -385,9 +436,9 @@ async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -2, @@ -433,9 +484,9 @@ async def test_explicit_user_model_change_clears_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-2)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"}, ], @@ -458,9 +509,9 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-999)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ {"id": -2, "litellm_provider": "openai", "model_name": "gpt-free", "api_key": "k1"}, ], @@ -487,7 +538,7 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch): # --------------------------------------------------------------------------- -# Quality-aware pin selection (Auto Fastest upgrade) +# Quality-aware pin selection (Auto upgrade) # --------------------------------------------------------------------------- @@ -498,9 +549,9 @@ async def test_health_gated_config_is_excluded_from_selection(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -550,9 +601,9 @@ async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch): from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -602,9 +653,9 @@ async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch from app.config import config session = _FakeSession(_thread()) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -676,9 +727,9 @@ async def test_top_k_picks_only_high_score_models(monkeypatch): "quality_score": 10, "health_gated": False, } - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [*high_score_cfgs, low_score_trap], ) @@ -723,9 +774,9 @@ async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -775,9 +826,9 @@ async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -833,9 +884,9 @@ async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -886,9 +937,9 @@ async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch): from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, @@ -931,9 +982,9 @@ async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypa from app.config import config session = _FakeSession(_thread(pinned_llm_config_id=-1)) - monkeypatch.setattr( + _set_global_llm_configs( + monkeypatch, config, - "GLOBAL_LLM_CONFIGS", [ { "id": -1, diff --git a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py index 63aa934a3..5850dfe23 100644 --- a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py +++ b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py @@ -15,15 +15,19 @@ async def test_global_openrouter_image_gen_sets_explicit_api_base(): """The global-config branch forwards the explicit OpenRouter base.""" from app.routes import image_generation_routes - cfg = { + global_model = { "id": -20_001, - "name": "GPT Image 1 (OpenRouter)", - "litellm_provider": "openrouter", - "model_name": "openai/gpt-image-1", + "connection_id": -101, + "model_id": "openai/gpt-image-1", + "supports_image_generation": True, + "capabilities_override": {}, + } + global_connection = { + "id": -101, + "provider": "openrouter", "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - "api_version": None, - "litellm_params": {}, + "base_url": "https://openrouter.ai/api/v1", + "extra": {}, } captured: dict = {} @@ -33,7 +37,7 @@ async def test_global_openrouter_image_gen_sets_explicit_api_base(): return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={}) image_gen = MagicMock() - image_gen.image_generation_config_id = cfg["id"] + image_gen.image_gen_model_id = global_model["id"] image_gen.prompt = "test" image_gen.n = 1 image_gen.quality = None @@ -43,14 +47,19 @@ async def test_global_openrouter_image_gen_sets_explicit_api_base(): image_gen.model = None search_space = MagicMock() - search_space.image_generation_config_id = cfg["id"] + search_space.image_gen_model_id = global_model["id"] session = MagicMock() with ( patch.object( image_generation_routes, - "_get_global_image_gen_config", - return_value=cfg, + "_get_global_model", + return_value=global_model, + ), + patch.object( + image_generation_routes, + "_get_global_connection", + return_value=global_connection, ), patch.object( image_generation_routes, @@ -74,15 +83,19 @@ async def test_generate_image_tool_global_sets_explicit_api_base(): generate_image as gi_module, ) - cfg = { + global_model = { "id": -20_001, - "name": "GPT Image 1 (OpenRouter)", - "litellm_provider": "openrouter", - "model_name": "openai/gpt-image-1", + "connection_id": -101, + "model_id": "openai/gpt-image-1", + "supports_image_generation": True, + "capabilities_override": {}, + } + global_connection = { + "id": -101, + "provider": "openrouter", "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - "api_version": None, - "litellm_params": {}, + "base_url": "https://openrouter.ai/api/v1", + "extra": {}, } captured: dict = {} @@ -98,7 +111,7 @@ async def test_generate_image_tool_global_sets_explicit_api_base(): search_space = MagicMock() search_space.id = 1 - search_space.image_generation_config_id = cfg["id"] + search_space.image_gen_model_id = global_model["id"] session_cm = AsyncMock() session = AsyncMock() @@ -121,7 +134,8 @@ async def test_generate_image_tool_global_sets_explicit_api_base(): with ( patch.object(gi_module, "shielded_async_session", return_value=session_cm), - patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg), + patch.object(gi_module, "_get_global_model", return_value=global_model), + patch.object(gi_module, "_get_global_connection", return_value=global_connection), patch.object( gi_module, "aimage_generation", side_effect=fake_aimage_generation ), diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py index 9d4c1a04b..ee97aac4d 100644 --- a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -217,7 +217,7 @@ def test_generate_configs_drops_non_text_and_non_tool_models(): # --------------------------------------------------------------------------- -# _generate_image_gen_configs / _generate_vision_llm_configs +# _generate_image_gen_configs # --------------------------------------------------------------------------- @@ -263,7 +263,7 @@ def test_generate_image_gen_configs_filters_by_image_output(): # Each config must carry ``billing_tier`` for routing in image_generation_routes. for c in cfgs: assert c["billing_tier"] in {"free", "premium"} - assert c["litellm_provider"] == "openrouter" + assert c["provider"] == "openrouter" assert c[_OPENROUTER_DYNAMIC_MARKER] is True # Emit the OpenRouter base URL at source so every call path passes an # explicit api_base and cannot inherit a process-global endpoint. @@ -271,9 +271,7 @@ def test_generate_image_gen_configs_filters_by_image_output(): def test_generate_image_gen_configs_assigns_image_id_offset(): - """Image configs use a different id_offset (-20000) so their negative - IDs don't collide with chat configs (-10000) or vision configs (-30000). - """ + """Image configs use their own id_offset (-20000).""" from app.services.openrouter_integration_service import ( _generate_image_gen_configs, ) @@ -291,88 +289,3 @@ def test_generate_image_gen_configs_assigns_image_id_offset(): assert all(c["id"] < -20_000 + 1 for c in cfgs) assert all(c["id"] > -29_000_000 for c in cfgs) - -def test_generate_vision_llm_configs_filters_by_image_input_text_output(): - """Vision LLMs must accept image input AND emit text — pure image-gen - (no text out) and text-only (no image in) models are excluded. - """ - from app.services.openrouter_integration_service import ( - _generate_vision_llm_configs, - ) - - raw = [ - # GPT-4o: vision LLM (image in, text out) — must emit. - { - "id": "openai/gpt-4o", - "architecture": { - "input_modalities": ["text", "image"], - "output_modalities": ["text"], - }, - "context_length": 128_000, - "pricing": {"prompt": "0.000005", "completion": "0.000015"}, - }, - # Pure image generator — image *output*, no text out. Must NOT emit. - { - "id": "openai/gpt-image-1", - "architecture": { - "input_modalities": ["text"], - "output_modalities": ["image"], - }, - "context_length": 4_000, - "pricing": {"prompt": "0", "completion": "0"}, - }, - # Pure text model (no image in). Must NOT emit. - { - "id": "anthropic/claude-3-haiku", - "architecture": { - "input_modalities": ["text"], - "output_modalities": ["text"], - }, - "context_length": 200_000, - "pricing": {"prompt": "0.000001", "completion": "0.000005"}, - }, - ] - - cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE)) - names = {c["model_name"] for c in cfgs} - assert names == {"openai/gpt-4o"} - - cfg = cfgs[0] - assert cfg["billing_tier"] == "premium" - # Pricing carried inline so pricing_registration can register vision - # under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache - # is cleared. - assert cfg["input_cost_per_token"] == pytest.approx(5e-6) - assert cfg["output_cost_per_token"] == pytest.approx(15e-6) - assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True - # Emit the OpenRouter base URL at source so every call path passes an - # explicit api_base and cannot inherit a process-global endpoint. - assert cfg["api_base"] == "https://openrouter.ai/api/v1" - - -def test_generate_vision_llm_configs_drops_chat_only_filters(): - """A small-context vision model that doesn't advertise tool calling is - still a valid vision LLM for "describe this image" prompts. The chat - filters (``supports_tool_calling``, ``has_sufficient_context``) must - NOT be applied to vision emission. - """ - from app.services.openrouter_integration_service import ( - _generate_vision_llm_configs, - ) - - raw = [ - { - "id": "tiny/vision-mini", - "architecture": { - "input_modalities": ["text", "image"], - "output_modalities": ["text"], - }, - "supported_parameters": [], # no tools - "context_length": 4_000, # well below MIN_CONTEXT_LENGTH - "pricing": {"prompt": "0.0000001", "completion": "0.0000005"}, - } - ] - - cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE)) - assert len(cfgs) == 1 - assert cfgs[0]["model_name"] == "tiny/vision-mini" diff --git a/surfsense_backend/tests/unit/services/test_pricing_registration.py b/surfsense_backend/tests/unit/services/test_pricing_registration.py index c9adc6aac..ee2faf674 100644 --- a/surfsense_backend/tests/unit/services/test_pricing_registration.py +++ b/surfsense_backend/tests/unit/services/test_pricing_registration.py @@ -370,77 +370,3 @@ def test_register_continues_after_individual_failure(monkeypatch, caplog): assert any("custom-deployment" in payload for payload in successful_calls) -def test_vision_configs_registered_with_chat_shape(monkeypatch): - """``register_pricing_from_global_configs`` walks - ``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision - calls (during indexing) bill correctly. Vision configs use the same - chat-shape token prices, but image-gen pricing is intentionally NOT - registered here (handled via ``response_cost`` in LiteLLM). - """ - from app.config import config - from app.services.pricing_registration import register_pricing_from_global_configs - - spy = _patch_register(monkeypatch) - _patch_openrouter_pricing( - monkeypatch, - {"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}}, - ) - - # No chat configs — only vision. Proves the vision walk is a separate - # iteration, not piggy-backed on the chat list. - monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", []) - monkeypatch.setattr( - config, - "GLOBAL_VISION_LLM_CONFIGS", - [ - { - "id": -1, - "litellm_provider": "openrouter", - "model_name": "openai/gpt-4o", - "billing_tier": "premium", - "input_cost_per_token": 5e-6, - "output_cost_per_token": 15e-6, - } - ], - ) - - register_pricing_from_global_configs() - - assert "openrouter/openai/gpt-4o" in spy.all_keys - payload_value = spy.calls[0]["openrouter/openai/gpt-4o"] - assert payload_value["mode"] == "chat" - assert payload_value["litellm_provider"] == "openrouter" - assert payload_value["input_cost_per_token"] == pytest.approx(5e-6) - assert payload_value["output_cost_per_token"] == pytest.approx(15e-6) - - -def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch): - """If the OpenRouter pricing cache misses a vision model (different - catalogue surface), the vision walk falls back to inline - ``input_cost_per_token``/``output_cost_per_token`` on the cfg itself. - """ - from app.config import config - from app.services.pricing_registration import register_pricing_from_global_configs - - spy = _patch_register(monkeypatch) - _patch_openrouter_pricing(monkeypatch, {}) - - monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", []) - monkeypatch.setattr( - config, - "GLOBAL_VISION_LLM_CONFIGS", - [ - { - "id": -1, - "litellm_provider": "openrouter", - "model_name": "google/gemini-2.5-flash", - "billing_tier": "premium", - "input_cost_per_token": 1e-6, - "output_cost_per_token": 4e-6, - } - ], - ) - - register_pricing_from_global_configs() - - assert "openrouter/google/gemini-2.5-flash" in spy.all_keys diff --git a/surfsense_backend/tests/unit/services/test_quality_score.py b/surfsense_backend/tests/unit/services/test_quality_score.py index 369c8b8f3..cb3f7523a 100644 --- a/surfsense_backend/tests/unit/services/test_quality_score.py +++ b/surfsense_backend/tests/unit/services/test_quality_score.py @@ -1,4 +1,4 @@ -"""Unit tests for the Auto (Fastest) quality scoring module.""" +"""Unit tests for the Auto quality scoring module.""" from __future__ import annotations diff --git a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py deleted file mode 100644 index 48dfc8e0b..000000000 --- a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py +++ /dev/null @@ -1,77 +0,0 @@ -"""Vision LLM resolution must pass explicit per-config ``api_base``.""" - -from __future__ import annotations - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -pytestmark = pytest.mark.unit - - -@pytest.mark.asyncio -async def test_get_vision_llm_global_openrouter_sets_api_base(): - """Global negative-ID branch forwards the explicit OpenRouter base.""" - from app.services import llm_service - - cfg = { - "id": -30_001, - "name": "GPT-4o Vision (OpenRouter)", - "litellm_provider": "openrouter", - "model_name": "openai/gpt-4o", - "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - "api_version": None, - "litellm_params": {}, - "billing_tier": "free", - } - - search_space = MagicMock() - search_space.id = 1 - search_space.user_id = "user-x" - search_space.vision_llm_config_id = cfg["id"] - - session = AsyncMock() - scalars = MagicMock() - scalars.first.return_value = search_space - result = MagicMock() - result.scalars.return_value = scalars - session.execute.return_value = result - - captured: dict = {} - - class FakeSanitized: - def __init__(self, **kwargs): - captured.update(kwargs) - - with ( - patch( - "app.services.vision_llm_router_service.get_global_vision_llm_config", - return_value=cfg, - ), - patch( - "app.agents.chat.runtime.llm_config.SanitizedChatLiteLLM", - new=FakeSanitized, - ), - ): - await llm_service.get_vision_llm(session=session, search_space_id=1) - - assert captured.get("api_base") == "https://openrouter.ai/api/v1" - assert captured["model"] == "openrouter/openai/gpt-4o" - - -def test_vision_router_deployment_sets_api_base_when_config_empty(): - """Auto-mode vision router carries explicit api_base into deployments.""" - from app.services.vision_llm_router_service import VisionLLMRouterService - - deployment = VisionLLMRouterService._config_to_deployment( - { - "model_name": "openai/gpt-4o", - "litellm_provider": "openrouter", - "api_key": "sk-or-test", - "api_base": "https://openrouter.ai/api/v1", - } - ) - assert deployment is not None - assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1" - assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-4o" diff --git a/surfsense_evals/README.md b/surfsense_evals/README.md index c755c4de6..e6fc52ca1 100644 --- a/surfsense_evals/README.md +++ b/surfsense_evals/README.md @@ -77,7 +77,7 @@ The walkthrough above is `--scenario head-to-head` (default): both arms answer w | `symmetric-cheap` | `--provider-model` (cheap, text-only) | `--provider-model` (same) | Does pre-extracted image context let a non-vision LLM reason over image-heavy docs? | | `cost-arbitrage` | `--native-arm-model` (vision) | `--provider-model` (cheap) | How close does SurfSense get to a vision-native baseline at a fraction of per-query cost?| -In all three modes the **ingest-time** vision LLM is set on the SearchSpace's `vision_llm_config_id` (auto-picked from the strongest registered global OpenRouter vision config — `claude-sonnet-4.5` > `claude-opus-4.7` > `gpt-5` > `gemini-2.5-pro`, override with `--vision-llm `). What changes is which slug the *answering* models hit per arm. +In all three modes the **ingest-time** vision LLM is set on the SearchSpace's `vision_model_id` (auto-picked from the strongest registered global OpenRouter vision-capable model — `claude-sonnet-4.5` > `claude-opus-4.7` > `gpt-5` > `gemini-2.5-pro`, override with `--vision-llm `). What changes is which slug the *answering* models hit per arm. ### Ingest with vision, evaluate with a non-vision LLM (`symmetric-cheap`) @@ -118,7 +118,7 @@ python -m surfsense_evals report --suite medical Notes: - `cost-arbitrage` requires both `--provider-model` (the cheap SurfSense slug) AND `--native-arm-model `. -- `--vision-llm ` is optional; if omitted the harness queries `GET /api/v1/global-vision-llm-configs` and auto-picks the strongest registered one. Pass `--no-vision-llm-setup` if you want to keep whatever vision config is already attached to the SearchSpace. +- `--vision-llm ` is optional; if omitted the harness queries `GET /api/v1/model-connections/global` and auto-picks the strongest registered vision-capable model. Pass `--no-vision-llm-setup` if you want to keep whatever vision model is already attached to the SearchSpace. - The runner's "looks text-only" warning is suppressed (or relabelled as informational) for `symmetric-cheap` so intentional asymmetry doesn't read as a misconfiguration. - All three scenario fields (`scenario`, `provider_model`, `native_arm_model`, `vision_provider_model`) are persisted to `state.json` and recorded in `run_artifact.extra` + the report header — no need to retrace what was set. diff --git a/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json b/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json index a4687f64a..b6c59e2bc 100644 --- a/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json +++ b/surfsense_evals/data/multimodal_doc/runs/2026-05-14T00-53-19Z/parser_compare/run_artifact.json @@ -9,7 +9,7 @@ "llamacloud_premium_lc", "surfsense_agentic" ], - "agent_llm_id": -5138454, + "chat_model_id": -5138454, "concurrency": 2, "llm_model": "anthropic/claude-sonnet-4.5", "n_pdfs": 30, diff --git a/surfsense_evals/src/surfsense_evals/core/cli.py b/surfsense_evals/src/surfsense_evals/core/cli.py index 3d4d0fd24..17979fba0 100644 --- a/surfsense_evals/src/surfsense_evals/core/cli.py +++ b/surfsense_evals/src/surfsense_evals/core/cli.py @@ -2,7 +2,7 @@ Subcommands: -* ``setup --suite --provider-model [--agent-llm-id ]`` +* ``setup --suite --provider-model [--chat-model-id ]`` * ``teardown --suite `` * ``models list [--provider openrouter] [--grep ]`` * ``suites list`` @@ -18,7 +18,7 @@ publish its own flags. Design choices worth flagging: -* ``setup`` rejects ``agent_llm_id == 0`` (Auto / LiteLLM router) so +* ``setup`` rejects ``chat_model_id == 0`` (Auto / LiteLLM router) so per-question accuracy is reproducible. * ``setup`` validates that the picked LLM config has ``provider == "OPENROUTER"`` and ``model_name == --provider-model`` @@ -59,7 +59,6 @@ if sys.platform == "win32": from . import registry from .auth import CredentialError, acquire_token, client_with_auth from .clients import SearchSpaceClient -from .clients.search_space import LlmPreferences from .config import ( DEFAULT_SCENARIO, SCENARIOS, @@ -111,23 +110,30 @@ class LlmConfigEntry: def from_payload(cls, payload: dict[str, Any]) -> LlmConfigEntry: return cls( id=int(payload["id"]), - name=str(payload.get("name", "")), + name=str(payload.get("display_name") or payload.get("name") or ""), provider=str(payload.get("provider", "")).upper(), - model_name=str(payload.get("model_name", "")), + model_name=str(payload.get("model_id") or payload.get("model_name") or ""), raw=payload, ) async def _list_global_llm_configs(http: httpx.AsyncClient, base: str) -> list[LlmConfigEntry]: response = await http.get( - f"{base}/api/v1/global-new-llm-configs", + f"{base}/api/v1/model-connections/global", headers={"Accept": "application/json"}, ) response.raise_for_status() payload = response.json() if not isinstance(payload, list): - raise RuntimeError(f"Unexpected /global-new-llm-configs payload: {payload!r}") - return [LlmConfigEntry.from_payload(item) for item in payload] + raise RuntimeError(f"Unexpected /model-connections/global payload: {payload!r}") + entries: list[LlmConfigEntry] = [] + for connection in payload: + provider = connection.get("provider", "") + for model in connection.get("models") or []: + if not model.get("enabled", True) or not model.get("supports_chat"): + continue + entries.append(LlmConfigEntry.from_payload({**model, "provider": provider})) + return entries def _resolve_openrouter_id( @@ -143,8 +149,8 @@ def _resolve_openrouter_id( * If ``explicit_id`` is given: return it directly. The caller is then expected to GET-validate that the row's ``provider == "OPENROUTER"`` and ``model_name`` matches the slug. - That branch supports positive BYOK ``NewLLMConfig`` rows whose - slugs may overlap with global OpenRouter virtuals. + That branch supports positive BYOK model rows whose slugs may overlap + with global OpenRouter virtuals. * Otherwise: filter to ``provider == "OPENROUTER"`` and ``model_name == provider_model``. Expect exactly one match — raise with a friendly message otherwise. @@ -173,7 +179,7 @@ def _resolve_openrouter_id( listing = "\n".join(f" id={c.id} name={c.name!r}" for c in matches) raise RuntimeError( f"Multiple OpenRouter configs for slug '{provider_model}':\n{listing}\n" - "Pass --agent-llm-id to disambiguate." + "Pass --chat-model-id to disambiguate." ) return matches[0].id @@ -186,7 +192,7 @@ def _resolve_openrouter_id( async def _cmd_setup(args: argparse.Namespace) -> int: suite = args.suite provider_model: str = args.provider_model - explicit_id: int | None = args.agent_llm_id + explicit_id: int | None = args.chat_model_id scenario: str = args.scenario vision_llm_slug: str | None = args.vision_llm native_arm_model: str | None = args.native_arm_model @@ -194,7 +200,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: if explicit_id == 0: console.print( - "[red]agent_llm_id == 0 (Auto / LiteLLM router) is not allowed — " + "[red]chat_model_id == 0 (Auto / LiteLLM router) is not allowed — " "results would not be reproducible.[/red]" ) return 2 @@ -242,7 +248,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: candidates = await _list_global_llm_configs(http, config.surfsense_api_base) try: - agent_llm_id = _resolve_openrouter_id( + chat_model_id = _resolve_openrouter_id( candidates, provider_model, explicit_id=explicit_id ) except RuntimeError as exc: @@ -288,7 +294,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: vision_provider_model: str | None = None if not skip_vision_setup and (vision_required or vision_llm_slug is not None): try: - vision_candidates = await ss_client.list_global_vision_llm_configs() + vision_candidates = await ss_client.list_global_vision_models() resolved = resolve_vision_llm( vision_candidates, explicit_slug=vision_llm_slug ) @@ -302,37 +308,34 @@ async def _cmd_setup(args: argparse.Namespace) -> int: f"(id={vision_config_id}, selected_via={resolved.selected_via})." ) - pref_kwargs: dict[str, Any] = {"agent_llm_id": agent_llm_id} + role_kwargs: dict[str, Any] = {"chat_model_id": chat_model_id} if vision_config_id is not None: - pref_kwargs["vision_llm_config_id"] = vision_config_id + role_kwargs["vision_model_id"] = vision_config_id - await ss_client.set_llm_preferences(search_space_id, **pref_kwargs) - prefs = await ss_client.get_llm_preferences(search_space_id) - if not _validate_pin(prefs, provider_model): - agent = prefs.agent_llm or {} + await ss_client.set_model_roles(search_space_id, **role_kwargs) + roles = await ss_client.get_model_roles(search_space_id) + if roles.chat_model_id != chat_model_id: console.print( f"[red]LLM pin validation FAILED.[/red] After PUT, " - f"agent_llm.provider={agent.get('provider')!r}, " - f"model_name={agent.get('model_name')!r}; expected " - f"provider=OPENROUTER, model_name={provider_model!r}." + f"chat_model_id={roles.chat_model_id!r}; expected {chat_model_id!r}." ) return 2 - if vision_config_id is not None and prefs.vision_llm_config_id != vision_config_id: + if vision_config_id is not None and roles.vision_model_id != vision_config_id: console.print( f"[red]Vision LLM pin validation FAILED.[/red] After PUT, " - f"vision_llm_config_id={prefs.vision_llm_config_id!r}; " + f"vision_model_id={roles.vision_model_id!r}; " f"expected {vision_config_id!r}." ) return 2 suite_state = SuiteState( search_space_id=search_space_id, - agent_llm_id=agent_llm_id, + chat_model_id=chat_model_id, provider_model=provider_model, created_at=utc_iso_timestamp(), ingestion_maps=existing.ingestion_maps if existing else {}, scenario=scenario, - vision_llm_config_id=vision_config_id, + vision_model_id=vision_config_id, vision_provider_model=vision_provider_model, native_arm_model=native_arm_model, ) @@ -342,7 +345,7 @@ async def _cmd_setup(args: argparse.Namespace) -> int: f"suite={suite!r}", f"scenario={scenario!r}", f"search_space_id={suite_state.search_space_id}", - f"agent_llm_id={suite_state.agent_llm_id}", + f"chat_model_id={suite_state.chat_model_id}", f"provider_model={suite_state.provider_model!r}", ] if suite_state.vision_provider_model: @@ -353,14 +356,6 @@ async def _cmd_setup(args: argparse.Namespace) -> int: return 0 -def _validate_pin(prefs: LlmPreferences, provider_model: str) -> bool: - agent = prefs.agent_llm or {} - return ( - str(agent.get("provider", "")).upper() == "OPENROUTER" - and str(agent.get("model_name", "")) == provider_model - ) - - async def _cmd_teardown(args: argparse.Namespace) -> int: suite = args.suite config = load_config() @@ -654,10 +649,10 @@ def _build_parser() -> argparse.ArgumentParser: ), ) p_setup.add_argument( - "--agent-llm-id", + "--chat-model-id", type=int, default=None, - help="Optional override for BYOK NewLLMConfig rows.", + help="Optional explicit model id override.", ) p_setup.add_argument( "--scenario", diff --git a/surfsense_evals/src/surfsense_evals/core/clients/search_space.py b/surfsense_evals/src/surfsense_evals/core/clients/search_space.py index e2d37694d..efd4a571d 100644 --- a/surfsense_evals/src/surfsense_evals/core/clients/search_space.py +++ b/surfsense_evals/src/surfsense_evals/core/clients/search_space.py @@ -1,17 +1,16 @@ -"""Client for ``/api/v1/searchspaces`` and ``/api/v1/search-spaces/{id}/llm-preferences``. +"""Client for ``/api/v1/searchspaces`` and model-role endpoints. Verified against: * ``surfsense_backend/app/routes/search_spaces_routes.py:116`` (POST create) * ``surfsense_backend/app/routes/search_spaces_routes.py:234`` (GET by id) * ``surfsense_backend/app/routes/search_spaces_routes.py:422`` (DELETE soft-delete) -* ``surfsense_backend/app/routes/search_spaces_routes.py:698-849`` (GET/PUT llm-preferences) +* ``surfsense_backend/app/routes/model_connections_routes.py`` (GET/PUT model roles) * ``surfsense_backend/app/schemas/search_space.py:14`` (SearchSpaceCreate body) -* ``surfsense_backend/app/routes/vision_llm_routes.py:60`` (GET global vision configs) Note the inconsistent pluralisation in the backend: ``/searchspaces`` -(no hyphen) for CRUD, but ``/search-spaces`` (hyphenated) for the -``llm-preferences`` sub-resource. Both are mirrored verbatim here. +(no hyphen) for CRUD, but ``/search-spaces`` (hyphenated) for model-role +sub-resources. Both are mirrored verbatim here. """ from __future__ import annotations @@ -46,13 +45,8 @@ class SearchSpaceRow: @dataclass -class VisionLlmConfigEntry: - """Subset of one ``GET /global-vision-llm-configs`` row. - - The backend returns negative ids for global / OpenRouter-derived - vision configs and positive ids for per-user BYOK rows. Either is - accepted by ``set_llm_preferences(vision_llm_config_id=...)``. - """ +class VisionModelEntry: + """Subset of one GLOBAL model-connection model with image input support.""" id: int name: str @@ -62,45 +56,38 @@ class VisionLlmConfigEntry: raw: dict[str, Any] @classmethod - def from_payload(cls, payload: dict[str, Any]) -> VisionLlmConfigEntry: + def from_payload(cls, payload: dict[str, Any]) -> VisionModelEntry: return cls( id=int(payload.get("id", 0)), - name=str(payload.get("name", "")), + name=str(payload.get("display_name") or payload.get("model_id") or ""), provider=str(payload.get("provider", "")).upper(), - model_name=str(payload.get("model_name", "")), - is_auto_mode=bool(payload.get("is_auto_mode", False)), + model_name=str(payload.get("model_id", "")), + is_auto_mode=False, raw=payload, ) @dataclass -class LlmPreferences: - """Resolved LLM preferences with the embedded full config row. +class ModelRoles: + """Model role ids for a search space.""" - Mirrors ``LLMPreferencesRead`` from the backend so the lifecycle - command can introspect ``provider`` / ``model_name`` to validate the - OpenRouter pin. - """ - - agent_llm_id: int | None - image_generation_config_id: int | None - vision_llm_config_id: int | None - agent_llm: dict[str, Any] | None + chat_model_id: int | None + image_gen_model_id: int | None + vision_model_id: int | None raw: dict[str, Any] @classmethod - def from_payload(cls, payload: dict[str, Any]) -> LlmPreferences: + def from_payload(cls, payload: dict[str, Any]) -> ModelRoles: return cls( - agent_llm_id=payload.get("agent_llm_id"), - image_generation_config_id=payload.get("image_generation_config_id"), - vision_llm_config_id=payload.get("vision_llm_config_id"), - agent_llm=payload.get("agent_llm"), + chat_model_id=payload.get("chat_model_id"), + image_gen_model_id=payload.get("image_gen_model_id"), + vision_model_id=payload.get("vision_model_id"), raw=payload, ) class SearchSpaceClient: - """Thin wrapper around the SearchSpace + LLM preferences endpoints.""" + """Thin wrapper around the SearchSpace + model role endpoints.""" def __init__(self, http: httpx.AsyncClient, base_url: str) -> None: self._http = http @@ -139,64 +126,67 @@ class SearchSpaceClient: return response.raise_for_status() - async def get_llm_preferences(self, search_space_id: int) -> LlmPreferences: + async def get_model_roles(self, search_space_id: int) -> ModelRoles: response = await self._http.get( - f"{self._base}/api/v1/search-spaces/{search_space_id}/llm-preferences", + f"{self._base}/api/v1/search-spaces/{search_space_id}/model-roles", headers={"Accept": "application/json"}, ) response.raise_for_status() - return LlmPreferences.from_payload(response.json()) + return ModelRoles.from_payload(response.json()) - async def set_llm_preferences( + async def set_model_roles( self, search_space_id: int, *, - agent_llm_id: int | None = None, - image_generation_config_id: int | None = None, - vision_llm_config_id: int | None = None, - ) -> LlmPreferences: - """PUT a partial update to ``/search-spaces/{id}/llm-preferences``. + chat_model_id: int | None = None, + image_gen_model_id: int | None = None, + vision_model_id: int | None = None, + ) -> ModelRoles: + """PUT a partial update to ``/search-spaces/{id}/model-roles``. Backend uses ``model_dump(exclude_unset=True)`` so omitted fields are left unchanged. """ body: dict[str, Any] = {} - if agent_llm_id is not None: - body["agent_llm_id"] = agent_llm_id - if image_generation_config_id is not None: - body["image_generation_config_id"] = image_generation_config_id - if vision_llm_config_id is not None: - body["vision_llm_config_id"] = vision_llm_config_id + if chat_model_id is not None: + body["chat_model_id"] = chat_model_id + if image_gen_model_id is not None: + body["image_gen_model_id"] = image_gen_model_id + if vision_model_id is not None: + body["vision_model_id"] = vision_model_id response = await self._http.put( - f"{self._base}/api/v1/search-spaces/{search_space_id}/llm-preferences", + f"{self._base}/api/v1/search-spaces/{search_space_id}/model-roles", json=body, headers={"Accept": "application/json"}, ) response.raise_for_status() - return LlmPreferences.from_payload(response.json()) + return ModelRoles.from_payload(response.json()) - async def list_global_vision_llm_configs(self) -> list[VisionLlmConfigEntry]: - """List the registered global vision LLM configs. + async def list_global_vision_models(self) -> list[VisionModelEntry]: + """List registered GLOBAL models that can accept image input. - Used by ``setup`` to (a) resolve an explicit ``--vision-llm `` - to a config id and (b) auto-pick the strongest registered vision - config when the operator doesn't pass one. The ``Auto (Fastest)`` - entry (``id=0``) is filtered out — accuracy must be reproducible. + Used by ``setup`` to resolve ``--vision-llm `` or auto-pick a + reproducible ingest-time vision model. """ response = await self._http.get( - f"{self._base}/api/v1/global-vision-llm-configs", + f"{self._base}/api/v1/model-connections/global", headers={"Accept": "application/json"}, ) response.raise_for_status() payload = response.json() if not isinstance(payload, list): raise RuntimeError( - f"Unexpected /global-vision-llm-configs payload: {payload!r}" + f"Unexpected /model-connections/global payload: {payload!r}" ) - return [ - VisionLlmConfigEntry.from_payload(item) - for item in payload - if not bool(item.get("is_auto_mode", False)) - ] + entries: list[VisionModelEntry] = [] + for connection in payload: + provider = str(connection.get("provider", "")) + for model in connection.get("models") or []: + if not model.get("enabled", True) or not model.get("supports_image_input"): + continue + entries.append( + VisionModelEntry.from_payload({**model, "provider": provider}) + ) + return entries diff --git a/surfsense_evals/src/surfsense_evals/core/config.py b/surfsense_evals/src/surfsense_evals/core/config.py index 164955914..9a5a71e89 100644 --- a/surfsense_evals/src/surfsense_evals/core/config.py +++ b/surfsense_evals/src/surfsense_evals/core/config.py @@ -147,35 +147,35 @@ class SuiteState: """Per-suite persisted state. ``provider_model`` is the slug pinned to the SearchSpace's - ``agent_llm`` — what answers SurfSense queries (and what the native + ``chat_model_id`` — what answers SurfSense queries (and what the native arm uses too, unless ``native_arm_model`` is set for cost-arbitrage). - ``vision_provider_model`` is the slug of the OpenRouter vision LLM - config attached to the SearchSpace's ``vision_llm_config_id`` — what + ``vision_provider_model`` is the slug of the OpenRouter vision model + attached to the SearchSpace's ``vision_model_id`` — what SurfSense uses to extract image content at ingest time when ``use_vision_llm=True``. ``None`` means no vision config was attached at setup (legacy or text-only suite). """ search_space_id: int - agent_llm_id: int + chat_model_id: int provider_model: str created_at: str ingestion_maps: dict[str, str] = field(default_factory=dict) scenario: str = DEFAULT_SCENARIO - vision_llm_config_id: int | None = None + vision_model_id: int | None = None vision_provider_model: str | None = None native_arm_model: str | None = None def to_dict(self) -> dict[str, Any]: return { "search_space_id": self.search_space_id, - "agent_llm_id": self.agent_llm_id, + "chat_model_id": self.chat_model_id, "provider_model": self.provider_model, "created_at": self.created_at, "ingestion_maps": dict(self.ingestion_maps), "scenario": self.scenario, - "vision_llm_config_id": self.vision_llm_config_id, + "vision_model_id": self.vision_model_id, "vision_provider_model": self.vision_provider_model, "native_arm_model": self.native_arm_model, } @@ -187,15 +187,16 @@ class SuiteState: scenario = str(payload.get("scenario") or DEFAULT_SCENARIO) if scenario not in SCENARIOS: scenario = DEFAULT_SCENARIO - raw_vision_id = payload.get("vision_llm_config_id") + raw_chat_id = payload.get("chat_model_id") + raw_vision_id = payload.get("vision_model_id") return cls( search_space_id=int(payload["search_space_id"]), - agent_llm_id=int(payload["agent_llm_id"]), + chat_model_id=int(raw_chat_id), provider_model=str(payload["provider_model"]), created_at=str(payload.get("created_at") or ""), ingestion_maps=dict(payload.get("ingestion_maps") or {}), scenario=scenario, - vision_llm_config_id=int(raw_vision_id) if raw_vision_id is not None else None, + vision_model_id=int(raw_vision_id) if raw_vision_id is not None else None, vision_provider_model=( str(payload["vision_provider_model"]) if payload.get("vision_provider_model") diff --git a/surfsense_evals/src/surfsense_evals/core/registry.py b/surfsense_evals/src/surfsense_evals/core/registry.py index cc8b725e0..65f64c39a 100644 --- a/surfsense_evals/src/surfsense_evals/core/registry.py +++ b/surfsense_evals/src/surfsense_evals/core/registry.py @@ -53,8 +53,8 @@ class RunContext: return self.suite_state.search_space_id @property - def agent_llm_id(self) -> int: - return self.suite_state.agent_llm_id + def chat_model_id(self) -> int: + return self.suite_state.chat_model_id @property def provider_model(self) -> str: diff --git a/surfsense_evals/src/surfsense_evals/core/vision_llm.py b/surfsense_evals/src/surfsense_evals/core/vision_llm.py index ae96f1285..5d5e2c6d1 100644 --- a/surfsense_evals/src/surfsense_evals/core/vision_llm.py +++ b/surfsense_evals/src/surfsense_evals/core/vision_llm.py @@ -3,8 +3,8 @@ Two responsibilities: 1. Resolve an explicit ``--vision-llm `` to a global OpenRouter - vision LLM config id that ``set_llm_preferences(vision_llm_config_id=...)`` - can accept. + vision-capable model id that ``set_model_roles(vision_model_id=...)`` can + accept. 2. Auto-pick the strongest registered vision config when the operator doesn't pass ``--vision-llm`` but the scenario / benchmark needs one. diff --git a/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py b/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py index e1a830138..ac0651996 100644 --- a/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/medical/medxpertqa/runner.py @@ -371,7 +371,7 @@ class MedXpertQAMMBenchmark: "provider_model": ctx.provider_model, "native_arm_model": native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "agent_llm_id": ctx.agent_llm_id, + "chat_model_id": ctx.chat_model_id, "ingest_settings": ingest_settings, }, ) diff --git a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py index 95a1e15eb..b7685766e 100644 --- a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/mmlongbench/runner.py @@ -391,7 +391,7 @@ class MMLongBenchDocBenchmark: "provider_model": ctx.provider_model, "native_arm_model": native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "agent_llm_id": ctx.agent_llm_id, + "chat_model_id": ctx.chat_model_id, "ingest_settings": ingest_settings, }, ) diff --git a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py index e71dffa65..2c4a0ffe4 100644 --- a/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/multimodal_doc/parser_compare/runner.py @@ -554,7 +554,7 @@ class ParserCompareBenchmark: "scenario": ctx.scenario, "provider_model": ctx.provider_model, "vision_provider_model": ctx.vision_provider_model, - "agent_llm_id": ctx.agent_llm_id, + "chat_model_id": ctx.chat_model_id, "preprocess_tariff": { "basic_per_1k_pages": 1.0, "premium_per_1k_pages": 10.0, diff --git a/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py b/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py index 8b759e0d8..654c261a2 100644 --- a/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/research/crag/runner.py @@ -467,7 +467,7 @@ class CragBenchmark: "provider_model": ctx.provider_model, "native_arm_model": ctx.native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "agent_llm_id": ctx.agent_llm_id, + "chat_model_id": ctx.chat_model_id, "ingest_settings": ingest_settings, "per_page_char_cap": per_page_char_cap, "max_output_tokens": max_output_tokens, diff --git a/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py b/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py index 9c0e16b00..450c7ddd6 100644 --- a/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py +++ b/surfsense_evals/src/surfsense_evals/suites/research/frames/runner.py @@ -372,7 +372,7 @@ class FramesBenchmark: "provider_model": ctx.provider_model, "native_arm_model": ctx.native_arm_model, "vision_provider_model": ctx.vision_provider_model, - "agent_llm_id": ctx.agent_llm_id, + "chat_model_id": ctx.chat_model_id, "ingest_settings": ingest_settings, "bare_arm_label": "bare_llm", }, diff --git a/surfsense_evals/tests/core/test_clients.py b/surfsense_evals/tests/core/test_clients.py index 611408703..aa98f0ad4 100644 --- a/surfsense_evals/tests/core/test_clients.py +++ b/surfsense_evals/tests/core/test_clients.py @@ -63,29 +63,22 @@ async def test_delete_search_space_idempotent_on_404(respx_mock, http): @pytest.mark.asyncio @respx.mock(base_url=_BASE) -async def test_set_llm_preferences_partial_update(respx_mock, http): - route = respx_mock.put("/api/v1/search-spaces/42/llm-preferences").mock( +async def test_set_model_roles_partial_update(respx_mock, http): + route = respx_mock.put("/api/v1/search-spaces/42/model-roles").mock( return_value=httpx.Response( 200, json={ - "agent_llm_id": -10042, - "agent_llm_id": None, - "image_generation_config_id": None, - "vision_llm_config_id": None, - "agent_llm": { - "id": -10042, - "provider": "OPENROUTER", - "model_name": "anthropic/claude-sonnet-4.5", - }, + "chat_model_id": -10042, + "image_gen_model_id": None, + "vision_model_id": None, }, ) ) client = SearchSpaceClient(http, _BASE) - prefs = await client.set_llm_preferences(42, agent_llm_id=-10042) - assert prefs.agent_llm_id == -10042 - assert prefs.agent_llm["provider"] == "OPENROUTER" + roles = await client.set_model_roles(42, chat_model_id=-10042) + assert roles.chat_model_id == -10042 sent_body = json.loads(route.calls[-1].request.content) - assert sent_body == {"agent_llm_id": -10042} + assert sent_body == {"chat_model_id": -10042} # --------------------------------------------------------------------------- diff --git a/surfsense_evals/tests/core/test_config.py b/surfsense_evals/tests/core/test_config.py index f7b8f7249..6f9671c86 100644 --- a/surfsense_evals/tests/core/test_config.py +++ b/surfsense_evals/tests/core/test_config.py @@ -41,14 +41,14 @@ def test_state_roundtrip_per_suite(tmp_env): # noqa: ARG001 assert get_suite_state(config, "medical") is None state = SuiteState( search_space_id=1, - agent_llm_id=-10042, + chat_model_id=-10042, provider_model="anthropic/claude-sonnet-4.5", created_at="2026-05-11T20-30-00Z", ) set_suite_state(config, "medical", state) legal = SuiteState( search_space_id=2, - agent_llm_id=-1, + chat_model_id=-1, provider_model="openai/gpt-5", created_at="2026-05-11T21-00-00Z", ) @@ -84,25 +84,19 @@ def test_paths_are_per_suite(tmp_env): # noqa: ARG001 # --------------------------------------------------------------------------- -def test_legacy_state_back_compat_defaults_to_head_to_head(): - """state.json files written before scenarios shipped must still load. +def test_minimal_state_defaults_to_head_to_head(): + """Missing scenario / vision / native fields default safely.""" - Missing ``scenario`` / ``vision_*`` / ``native_arm_model`` keys all - default to ``head-to-head`` / ``None`` so old setups keep working - after upgrade — the runner's behaviour exactly mirrors the legacy - one (both arms answer with ``provider_model``). - """ - - legacy = { + payload = { "search_space_id": 7, - "agent_llm_id": -123, + "chat_model_id": -123, "provider_model": "anthropic/claude-sonnet-4.5", "created_at": "2026-05-11T20-30-00Z", "ingestion_maps": {}, } - state = SuiteState.from_dict(legacy) + state = SuiteState.from_dict(payload) assert state.scenario == DEFAULT_SCENARIO == "head-to-head" - assert state.vision_llm_config_id is None + assert state.vision_model_id is None assert state.vision_provider_model is None assert state.native_arm_model is None # The native arm should still answer with the same slug as SurfSense. @@ -118,7 +112,7 @@ def test_unknown_scenario_falls_back_to_default(): payload = { "search_space_id": 1, - "agent_llm_id": -1, + "chat_model_id": -1, "provider_model": "openai/gpt-5", "scenario": "unknown-scenario-name", } @@ -130,11 +124,11 @@ def test_cost_arbitrage_state_persists_native_arm_model(tmp_env): # noqa: ARG00 config = load_config() state = SuiteState( search_space_id=42, - agent_llm_id=-1, + chat_model_id=-1, provider_model="openai/gpt-5.4-mini", created_at="2026-05-11T20-30-00Z", scenario="cost-arbitrage", - vision_llm_config_id=-101, + vision_model_id=-101, vision_provider_model="anthropic/claude-sonnet-4.5", native_arm_model="anthropic/claude-sonnet-4.5", ) @@ -142,7 +136,7 @@ def test_cost_arbitrage_state_persists_native_arm_model(tmp_env): # noqa: ARG00 fetched = get_suite_state(config, "medical") assert fetched.scenario == "cost-arbitrage" - assert fetched.vision_llm_config_id == -101 + assert fetched.vision_model_id == -101 assert fetched.vision_provider_model == "anthropic/claude-sonnet-4.5" assert fetched.native_arm_model == "anthropic/claude-sonnet-4.5" # Cost arbitrage's whole point: native arm slug != surfsense slug. diff --git a/surfsense_evals/tests/test_integration_smoke.py b/surfsense_evals/tests/test_integration_smoke.py index 493c04c25..1c89ae5ab 100644 --- a/surfsense_evals/tests/test_integration_smoke.py +++ b/surfsense_evals/tests/test_integration_smoke.py @@ -27,7 +27,7 @@ async def test_smoke_against_localhost(): pytest.skip("No credentials in environment; skipping integration smoke") bundle = await acquire_token(config) async with client_with_auth(config, bundle) as client: - response = await client.get(f"{config.surfsense_api_base}/api/v1/global-new-llm-configs") + response = await client.get(f"{config.surfsense_api_base}/api/v1/model-connections/global") try: response.raise_for_status() except httpx.HTTPStatusError as exc: diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/image-models/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/image-models/page.tsx deleted file mode 100644 index b300f8078..000000000 --- a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/image-models/page.tsx +++ /dev/null @@ -1,6 +0,0 @@ -import { ImageModelManager } from "@/components/settings/image-model-manager"; - -export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) { - const { search_space_id } = await params; - return ; -} diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/roles/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/roles/page.tsx deleted file mode 100644 index 5bad50cd3..000000000 --- a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/roles/page.tsx +++ /dev/null @@ -1,6 +0,0 @@ -import { LLMRoleManager } from "@/components/settings/llm-role-manager"; - -export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) { - const { search_space_id } = await params; - return ; -} diff --git a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/vision-models/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/vision-models/page.tsx deleted file mode 100644 index 06aea003a..000000000 --- a/surfsense_web/app/dashboard/[search_space_id]/search-space-settings/vision-models/page.tsx +++ /dev/null @@ -1,6 +0,0 @@ -import { VisionModelManager } from "@/components/settings/vision-model-manager"; - -export default async function Page({ params }: { params: Promise<{ search_space_id: string }> }) { - const { search_space_id } = await params; - return ; -} diff --git a/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts b/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts deleted file mode 100644 index 922c398c9..000000000 --- a/surfsense_web/atoms/image-gen-config/image-gen-config-mutation.atoms.ts +++ /dev/null @@ -1,96 +0,0 @@ -import { atomWithMutation } from "jotai-tanstack-query"; -import { toast } from "sonner"; -import type { - CreateImageGenConfigRequest, - CreateImageGenConfigResponse, - DeleteImageGenConfigResponse, - GetImageGenConfigsResponse, - UpdateImageGenConfigRequest, - UpdateImageGenConfigResponse, -} from "@/contracts/types/new-llm-config.types"; -import { imageGenConfigApiService } from "@/lib/apis/image-gen-config-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { queryClient } from "@/lib/query-client/client"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -/** - * Mutation atom for creating a new ImageGenerationConfig - */ -export const createImageGenConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["image-gen-configs", "create"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: CreateImageGenConfigRequest) => { - return imageGenConfigApiService.createConfig(request); - }, - onSuccess: (_: CreateImageGenConfigResponse, request: CreateImageGenConfigRequest) => { - toast.success(`${request.name} created`); - queryClient.invalidateQueries({ - queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to create image model"); - }, - }; -}); - -/** - * Mutation atom for updating an existing ImageGenerationConfig - */ -export const updateImageGenConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["image-gen-configs", "update"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: UpdateImageGenConfigRequest) => { - return imageGenConfigApiService.updateConfig(request); - }, - onSuccess: (_: UpdateImageGenConfigResponse, request: UpdateImageGenConfigRequest) => { - toast.success(`${request.data.name ?? "Configuration"} updated`); - queryClient.invalidateQueries({ - queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), - }); - queryClient.invalidateQueries({ - queryKey: cacheKeys.imageGenConfigs.byId(request.id), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to update image model"); - }, - }; -}); - -/** - * Mutation atom for deleting an ImageGenerationConfig - */ -export const deleteImageGenConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["image-gen-configs", "delete"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: { id: number; name: string }) => { - return imageGenConfigApiService.deleteConfig(request.id); - }, - onSuccess: (_: DeleteImageGenConfigResponse, request: { id: number; name: string }) => { - toast.success(`${request.name} deleted`); - queryClient.setQueryData( - cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), - (oldData: GetImageGenConfigsResponse | undefined) => { - if (!oldData) return oldData; - return oldData.filter((config) => config.id !== request.id); - } - ); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to delete image model"); - }, - }; -}); diff --git a/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts b/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts deleted file mode 100644 index a45e69a03..000000000 --- a/surfsense_web/atoms/image-gen-config/image-gen-config-query.atoms.ts +++ /dev/null @@ -1,33 +0,0 @@ -import { atomWithQuery } from "jotai-tanstack-query"; -import { imageGenConfigApiService } from "@/lib/apis/image-gen-config-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -/** - * Query atom for fetching user-created image gen configs for the active search space - */ -export const imageGenConfigsAtom = atomWithQuery((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - queryKey: cacheKeys.imageGenConfigs.all(Number(searchSpaceId)), - enabled: !!searchSpaceId, - staleTime: 5 * 60 * 1000, // 5 minutes - queryFn: async () => { - return imageGenConfigApiService.getConfigs(Number(searchSpaceId)); - }, - }; -}); - -/** - * Query atom for fetching global image gen configs (from YAML, negative IDs) - */ -export const globalImageGenConfigsAtom = atomWithQuery(() => { - return { - queryKey: cacheKeys.imageGenConfigs.global(), - staleTime: 10 * 60 * 1000, // 10 minutes - global configs rarely change - queryFn: async () => { - return imageGenConfigApiService.getGlobalConfigs(); - }, - }; -}); diff --git a/surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts b/surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts deleted file mode 100644 index 476d89d4c..000000000 --- a/surfsense_web/atoms/new-llm-config/new-llm-config-mutation.atoms.ts +++ /dev/null @@ -1,132 +0,0 @@ -import { atomWithMutation } from "jotai-tanstack-query"; -import { toast } from "sonner"; -import type { - CreateNewLLMConfigRequest, - CreateNewLLMConfigResponse, - DeleteNewLLMConfigRequest, - DeleteNewLLMConfigResponse, - GetNewLLMConfigsResponse, - UpdateLLMPreferencesRequest, - UpdateNewLLMConfigRequest, - UpdateNewLLMConfigResponse, -} from "@/contracts/types/new-llm-config.types"; -import { newLLMConfigApiService } from "@/lib/apis/new-llm-config-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { queryClient } from "@/lib/query-client/client"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -/** - * Mutation atom for creating a new NewLLMConfig - */ -export const createNewLLMConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["new-llm-configs", "create"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: CreateNewLLMConfigRequest) => { - return newLLMConfigApiService.createConfig(request); - }, - onSuccess: (_: CreateNewLLMConfigResponse, request: CreateNewLLMConfigRequest) => { - toast.success(`${request.name} created`); - queryClient.invalidateQueries({ - queryKey: cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to create model"); - }, - }; -}); - -/** - * Mutation atom for updating an existing NewLLMConfig - */ -export const updateNewLLMConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["new-llm-configs", "update"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: UpdateNewLLMConfigRequest) => { - return newLLMConfigApiService.updateConfig(request); - }, - onSuccess: (_: UpdateNewLLMConfigResponse, request: UpdateNewLLMConfigRequest) => { - toast.success(`${request.data.name ?? "Configuration"} updated`); - queryClient.invalidateQueries({ - queryKey: cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), - }); - queryClient.invalidateQueries({ - queryKey: cacheKeys.newLLMConfigs.byId(request.id), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to update"); - }, - }; -}); - -/** - * Mutation atom for deleting a NewLLMConfig - */ -export const deleteNewLLMConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["new-llm-configs", "delete"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: DeleteNewLLMConfigRequest & { name: string }) => { - return newLLMConfigApiService.deleteConfig({ id: request.id }); - }, - onSuccess: ( - _: DeleteNewLLMConfigResponse, - request: DeleteNewLLMConfigRequest & { name: string } - ) => { - toast.success(`${request.name} deleted`); - queryClient.setQueryData( - cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), - (oldData: GetNewLLMConfigsResponse | undefined) => { - if (!oldData) return oldData; - return oldData.filter((config) => config.id !== request.id); - } - ); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to delete"); - }, - }; -}); - -/** - * Mutation atom for updating LLM preferences (role assignments) - */ -export const updateLLMPreferencesMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["llm-preferences", "update"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: UpdateLLMPreferencesRequest) => { - return newLLMConfigApiService.updateLLMPreferences(request); - }, - onSuccess: (_data, request: UpdateLLMPreferencesRequest) => { - queryClient.setQueryData( - cacheKeys.newLLMConfigs.preferences(Number(searchSpaceId)), - (old: Record | undefined) => ({ ...old, ...request.data }) - ); - // Automation eligibility is derived from these model preferences - // (agent/image/vision). Invalidate it so the automations gate alert - // reflects the new selection without a manual refresh. - queryClient.invalidateQueries({ - queryKey: cacheKeys.automations.modelEligibility(Number(searchSpaceId)), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to update LLM preferences"); - }, - }; -}); diff --git a/surfsense_web/atoms/new-llm-config/new-llm-config-query.atoms.ts b/surfsense_web/atoms/new-llm-config/new-llm-config-query.atoms.ts deleted file mode 100644 index 410d061e5..000000000 --- a/surfsense_web/atoms/new-llm-config/new-llm-config-query.atoms.ts +++ /dev/null @@ -1,98 +0,0 @@ -import { atomWithQuery } from "jotai-tanstack-query"; -import type { LLMModel } from "@/contracts/enums/llm-models"; -import { LLM_MODELS } from "@/contracts/enums/llm-models"; -import { newLLMConfigApiService } from "@/lib/apis/new-llm-config-api.service"; -import { getBearerToken } from "@/lib/auth-utils"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -/** - * Query atom for fetching all NewLLMConfigs for the active search space - */ -export const newLLMConfigsAtom = atomWithQuery((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - queryKey: cacheKeys.newLLMConfigs.all(Number(searchSpaceId)), - enabled: !!searchSpaceId, - staleTime: 5 * 60 * 1000, // 5 minutes - queryFn: async () => { - return newLLMConfigApiService.getConfigs({ - search_space_id: Number(searchSpaceId), - }); - }, - }; -}); - -/** - * Query atom for fetching global NewLLMConfigs (from YAML, negative IDs) - */ -export const globalNewLLMConfigsAtom = atomWithQuery(() => { - return { - queryKey: cacheKeys.newLLMConfigs.global(), - staleTime: 10 * 60 * 1000, // 10 minutes - global configs rarely change - enabled: !!getBearerToken(), - queryFn: async () => { - return newLLMConfigApiService.getGlobalConfigs(); - }, - }; -}); - -/** - * Query atom for fetching LLM preferences (role assignments) for the active search space - */ -export const llmPreferencesAtom = atomWithQuery((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - queryKey: cacheKeys.newLLMConfigs.preferences(Number(searchSpaceId)), - enabled: !!searchSpaceId, - staleTime: 5 * 60 * 1000, // 5 minutes - queryFn: async () => { - return newLLMConfigApiService.getLLMPreferences(Number(searchSpaceId)); - }, - }; -}); - -/** - * Query atom for fetching default system instructions template - */ -export const defaultSystemInstructionsAtom = atomWithQuery(() => { - return { - queryKey: cacheKeys.newLLMConfigs.defaultInstructions(), - staleTime: 60 * 60 * 1000, // 1 hour - this rarely changes - queryFn: async () => { - return newLLMConfigApiService.getDefaultSystemInstructions(); - }, - }; -}); - -/** - * Query atom for the dynamic model catalogue. - * Fetched from the backend (which proxies OpenRouter's public API). - * Falls back to the static hardcoded list on error. - */ -export const modelListAtom = atomWithQuery(() => { - return { - queryKey: cacheKeys.newLLMConfigs.modelList(), - staleTime: 60 * 60 * 1000, // 1 hour - models don't change often - placeholderData: LLM_MODELS, - queryFn: async (): Promise => { - const data = await newLLMConfigApiService.getModels(); - const dynamicModels = data.map((m) => ({ - value: m.value, - label: m.label, - provider: m.provider, - contextWindow: m.context_window ?? undefined, - })); - - // Providers covered by the dynamic API (from OpenRouter mapping). - // For uncovered providers (Ollama, Groq, Bedrock, etc.) keep the - // hand-curated static suggestions so users still see model options. - const coveredProviders = new Set(dynamicModels.map((m) => m.provider)); - const staticFallbacks = LLM_MODELS.filter((m) => !coveredProviders.has(m.provider)); - - return [...dynamicModels, ...staticFallbacks]; - }, - }; -}); diff --git a/surfsense_web/atoms/vision-llm-config/vision-llm-config-mutation.atoms.ts b/surfsense_web/atoms/vision-llm-config/vision-llm-config-mutation.atoms.ts deleted file mode 100644 index f46b977d5..000000000 --- a/surfsense_web/atoms/vision-llm-config/vision-llm-config-mutation.atoms.ts +++ /dev/null @@ -1,87 +0,0 @@ -import { atomWithMutation } from "jotai-tanstack-query"; -import { toast } from "sonner"; -import type { - CreateVisionLLMConfigRequest, - CreateVisionLLMConfigResponse, - DeleteVisionLLMConfigResponse, - GetVisionLLMConfigsResponse, - UpdateVisionLLMConfigRequest, - UpdateVisionLLMConfigResponse, -} from "@/contracts/types/new-llm-config.types"; -import { visionLLMConfigApiService } from "@/lib/apis/vision-llm-config-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { queryClient } from "@/lib/query-client/client"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -export const createVisionLLMConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["vision-llm-configs", "create"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: CreateVisionLLMConfigRequest) => { - return visionLLMConfigApiService.createConfig(request); - }, - onSuccess: (_: CreateVisionLLMConfigResponse, request: CreateVisionLLMConfigRequest) => { - toast.success(`${request.name} created`); - queryClient.invalidateQueries({ - queryKey: cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to create vision model"); - }, - }; -}); - -export const updateVisionLLMConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["vision-llm-configs", "update"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: UpdateVisionLLMConfigRequest) => { - return visionLLMConfigApiService.updateConfig(request); - }, - onSuccess: (_: UpdateVisionLLMConfigResponse, request: UpdateVisionLLMConfigRequest) => { - toast.success(`${request.data.name ?? "Configuration"} updated`); - queryClient.invalidateQueries({ - queryKey: cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), - }); - queryClient.invalidateQueries({ - queryKey: cacheKeys.visionLLMConfigs.byId(request.id), - }); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to update vision model"); - }, - }; -}); - -export const deleteVisionLLMConfigMutationAtom = atomWithMutation((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - mutationKey: ["vision-llm-configs", "delete"], - meta: { suppressGlobalErrorToast: true }, - enabled: !!searchSpaceId, - mutationFn: async (request: { id: number; name: string }) => { - return visionLLMConfigApiService.deleteConfig(request.id); - }, - onSuccess: (_: DeleteVisionLLMConfigResponse, request: { id: number; name: string }) => { - toast.success(`${request.name} deleted`); - queryClient.setQueryData( - cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), - (oldData: GetVisionLLMConfigsResponse | undefined) => { - if (!oldData) return oldData; - return oldData.filter((config) => config.id !== request.id); - } - ); - }, - onError: (error: Error) => { - toast.error(error.message || "Failed to delete vision model"); - }, - }; -}); diff --git a/surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts b/surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts deleted file mode 100644 index 906ce638f..000000000 --- a/surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts +++ /dev/null @@ -1,51 +0,0 @@ -import { atomWithQuery } from "jotai-tanstack-query"; -import type { LLMModel } from "@/contracts/enums/llm-models"; -import { VISION_MODELS } from "@/contracts/enums/vision-providers"; -import { visionLLMConfigApiService } from "@/lib/apis/vision-llm-config-api.service"; -import { cacheKeys } from "@/lib/query-client/cache-keys"; -import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; - -export const visionLLMConfigsAtom = atomWithQuery((get) => { - const searchSpaceId = get(activeSearchSpaceIdAtom); - - return { - queryKey: cacheKeys.visionLLMConfigs.all(Number(searchSpaceId)), - enabled: !!searchSpaceId, - staleTime: 5 * 60 * 1000, - queryFn: async () => { - return visionLLMConfigApiService.getConfigs(Number(searchSpaceId)); - }, - }; -}); - -export const globalVisionLLMConfigsAtom = atomWithQuery(() => { - return { - queryKey: cacheKeys.visionLLMConfigs.global(), - staleTime: 10 * 60 * 1000, - queryFn: async () => { - return visionLLMConfigApiService.getGlobalConfigs(); - }, - }; -}); - -export const visionModelListAtom = atomWithQuery(() => { - return { - queryKey: cacheKeys.visionLLMConfigs.modelList(), - staleTime: 60 * 60 * 1000, - placeholderData: VISION_MODELS, - queryFn: async (): Promise => { - const data = await visionLLMConfigApiService.getModels(); - const dynamicModels = data.map((m) => ({ - value: m.value, - label: m.label, - provider: m.provider, - contextWindow: m.context_window ?? undefined, - })); - - const coveredProviders = new Set(dynamicModels.map((m) => m.provider)); - const staticFallbacks = VISION_MODELS.filter((m) => !coveredProviders.has(m.provider)); - - return [...dynamicModels, ...staticFallbacks]; - }, - }; -}); diff --git a/surfsense_web/components/new-chat/chat-header.tsx b/surfsense_web/components/new-chat/chat-header.tsx index 4716418ee..d65dc93a7 100644 --- a/surfsense_web/components/new-chat/chat-header.tsx +++ b/surfsense_web/components/new-chat/chat-header.tsx @@ -1,17 +1,5 @@ "use client"; -import { useCallback, useState } from "react"; -import { ImageConfigDialog } from "@/components/shared/image-config-dialog"; -import { ModelConfigDialog } from "@/components/shared/model-config-dialog"; -import { VisionConfigDialog } from "@/components/shared/vision-config-dialog"; -import type { - GlobalImageGenConfig, - GlobalNewLLMConfig, - GlobalVisionLLMConfig, - ImageGenerationConfig, - NewLLMConfigPublic, - VisionLLMConfig, -} from "@/contracts/types/new-llm-config.types"; import { ModelSelector } from "./model-selector"; interface ChatHeaderProps { @@ -20,148 +8,9 @@ interface ChatHeaderProps { } export function ChatHeader({ searchSpaceId, className }: ChatHeaderProps) { - // LLM config dialog state - const [dialogOpen, setDialogOpen] = useState(false); - const [selectedConfig, setSelectedConfig] = useState< - NewLLMConfigPublic | GlobalNewLLMConfig | null - >(null); - const [isGlobal, setIsGlobal] = useState(false); - const [dialogMode, setDialogMode] = useState<"create" | "edit" | "view">("view"); - - // Image config dialog state - const [imageDialogOpen, setImageDialogOpen] = useState(false); - const [selectedImageConfig, setSelectedImageConfig] = useState< - ImageGenerationConfig | GlobalImageGenConfig | null - >(null); - const [isImageGlobal, setIsImageGlobal] = useState(false); - const [imageDialogMode, setImageDialogMode] = useState<"create" | "edit" | "view">("view"); - - // Vision config dialog state - const [visionDialogOpen, setVisionDialogOpen] = useState(false); - const [selectedVisionConfig, setSelectedVisionConfig] = useState< - VisionLLMConfig | GlobalVisionLLMConfig | null - >(null); - const [isVisionGlobal, setIsVisionGlobal] = useState(false); - const [visionDialogMode, setVisionDialogMode] = useState<"create" | "edit" | "view">("view"); - - // Default provider for create dialogs - const [defaultLLMProvider, setDefaultLLMProvider] = useState(); - const [defaultImageProvider, setDefaultImageProvider] = useState(); - const [defaultVisionProvider, setDefaultVisionProvider] = useState(); - - // LLM handlers - const handleEditLLMConfig = useCallback( - (config: NewLLMConfigPublic | GlobalNewLLMConfig, global: boolean) => { - setSelectedConfig(config); - setIsGlobal(global); - setDialogMode(global ? "view" : "edit"); - setDefaultLLMProvider(undefined); - setDialogOpen(true); - }, - [] - ); - - const handleAddNewLLM = useCallback((provider?: string) => { - setSelectedConfig(null); - setIsGlobal(false); - setDialogMode("create"); - setDefaultLLMProvider(provider); - setDialogOpen(true); - }, []); - - const handleDialogClose = useCallback((open: boolean) => { - setDialogOpen(open); - if (!open) setSelectedConfig(null); - }, []); - - // Image model handlers - const handleAddImageModel = useCallback((provider?: string) => { - setSelectedImageConfig(null); - setIsImageGlobal(false); - setImageDialogMode("create"); - setDefaultImageProvider(provider); - setImageDialogOpen(true); - }, []); - - const handleEditImageConfig = useCallback( - (config: ImageGenerationConfig | GlobalImageGenConfig, global: boolean) => { - setSelectedImageConfig(config); - setIsImageGlobal(global); - setImageDialogMode(global ? "view" : "edit"); - setDefaultImageProvider(undefined); - setImageDialogOpen(true); - }, - [] - ); - - const handleImageDialogClose = useCallback((open: boolean) => { - setImageDialogOpen(open); - if (!open) setSelectedImageConfig(null); - }, []); - - // Vision model handlers - const handleAddVisionModel = useCallback((provider?: string) => { - setSelectedVisionConfig(null); - setIsVisionGlobal(false); - setVisionDialogMode("create"); - setDefaultVisionProvider(provider); - setVisionDialogOpen(true); - }, []); - - const handleEditVisionConfig = useCallback( - (config: VisionLLMConfig | GlobalVisionLLMConfig, global: boolean) => { - setSelectedVisionConfig(config); - setIsVisionGlobal(global); - setVisionDialogMode(global ? "view" : "edit"); - setDefaultVisionProvider(undefined); - setVisionDialogOpen(true); - }, - [] - ); - - const handleVisionDialogClose = useCallback((open: boolean) => { - setVisionDialogOpen(open); - if (!open) setSelectedVisionConfig(null); - }, []); - return (
- - - - +
); } diff --git a/surfsense_web/components/settings/agent-model-manager.tsx b/surfsense_web/components/settings/agent-model-manager.tsx deleted file mode 100644 index 507a263e0..000000000 --- a/surfsense_web/components/settings/agent-model-manager.tsx +++ /dev/null @@ -1,423 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { AlertCircle, Dot, FileText, Info, Pencil, RefreshCw, Trash2 } from "lucide-react"; -import { useMemo, useState } from "react"; -import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; -import { deleteNewLLMConfigMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; -import { - globalNewLLMConfigsAtom, - newLLMConfigsAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; -import { ModelConfigDialog } from "@/components/shared/model-config-dialog"; -import { Alert, AlertDescription } from "@/components/ui/alert"; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, -} from "@/components/ui/alert-dialog"; -import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Card, CardContent } from "@/components/ui/card"; -import { Separator } from "@/components/ui/separator"; -import { Skeleton } from "@/components/ui/skeleton"; -import { Spinner } from "@/components/ui/spinner"; -import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; -import type { NewLLMConfig } from "@/contracts/types/new-llm-config.types"; -import { useMediaQuery } from "@/hooks/use-media-query"; -import { getProviderIcon } from "@/lib/provider-icons"; -import { cn } from "@/lib/utils"; - -interface AgentModelManagerProps { - searchSpaceId: number; -} - -function getInitials(name: string): string { - const parts = name.trim().split(/\s+/); - if (parts.length >= 2) { - return (parts[0][0] + parts[1][0]).toUpperCase(); - } - return name.slice(0, 2).toUpperCase(); -} - -export function AgentModelManager({ searchSpaceId }: AgentModelManagerProps) { - const isDesktop = useMediaQuery("(min-width: 768px)"); - // Mutations - const { mutateAsync: deleteConfig, isPending: isDeleting } = useAtomValue( - deleteNewLLMConfigMutationAtom - ); - - // Queries - const { - data: configs, - isFetching: isLoading, - error: fetchError, - refetch: refreshConfigs, - } = useAtomValue(newLLMConfigsAtom); - const { data: globalConfigs = [] } = useAtomValue(globalNewLLMConfigsAtom); - - // Members for user resolution - const { data: members } = useAtomValue(membersAtom); - const memberMap = useMemo(() => { - const map = new Map(); - if (members) { - for (const m of members) { - map.set(m.user_id, { - name: m.user_display_name || m.user_email || "Unknown", - email: m.user_email || undefined, - avatarUrl: m.user_avatar_url || undefined, - }); - } - } - return map; - }, [members]); - - // Permissions - const { data: access } = useAtomValue(myAccessAtom); - const canCreate = - !!access && (access.is_owner || (access.permissions?.includes("llm_configs:create") ?? false)); - const canUpdate = - !!access && (access.is_owner || (access.permissions?.includes("llm_configs:update") ?? false)); - const canDelete = - !!access && (access.is_owner || (access.permissions?.includes("llm_configs:delete") ?? false)); - const isReadOnly = !canCreate && !canUpdate && !canDelete; - - // Local state - const [isDialogOpen, setIsDialogOpen] = useState(false); - const [editingConfig, setEditingConfig] = useState(null); - const [configToDelete, setConfigToDelete] = useState(null); - - const handleDelete = async () => { - if (!configToDelete) return; - try { - await deleteConfig({ id: configToDelete.id, name: configToDelete.name }); - setConfigToDelete(null); - } catch { - // Error handled by mutation state - } - }; - - const openEditDialog = (config: NewLLMConfig) => { - setEditingConfig(config); - setIsDialogOpen(true); - }; - - const openNewDialog = () => { - setEditingConfig(null); - setIsDialogOpen(true); - }; - - return ( -
- {/* Header actions */} -
- - {canCreate && ( - - )} -
- - {/* Fetch Error Alert */} - {fetchError && ( -
- - - - {fetchError?.message ?? "Failed to load configurations"} - - -
- )} - - {/* Read-only / Limited permissions notice */} - {access && !isLoading && isReadOnly && ( -
- - - -

- You have read-only access to LLM - configurations. Contact a space owner to request additional permissions. -

-
-
-
- )} - {access && !isLoading && !isReadOnly && (!canCreate || !canUpdate || !canDelete) && ( -
- - - -

- You can{" "} - {[canCreate && "create", canUpdate && "edit", canDelete && "delete"] - .filter(Boolean) - .join(" and ")}{" "} - configurations - {!canDelete && ", but cannot delete them"}. -

-
-
-
- )} - - {/* Global Configs Info */} - {(isLoading || globalConfigs.length > 0) && ( - - - - {isLoading ? ( -
- -
- ) : ( -

- - {globalConfigs.length} global {globalConfigs.length === 1 ? "model" : "models"} - {" "} - available from your administrator. -

- )} -
-
- )} - - {/* Loading Skeleton */} - {isLoading && ( -
- {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( - - - - - - - - ))} -
- )} - - {/* Configurations List */} - {!isLoading && ( -
- {configs?.length === 0 ? ( -
- - -

No Models Yet

-

- {canCreate - ? "Add your first model to power chat, reports, and other agent capabilities" - : "No models have been added to this space yet. Contact a space owner to add one"} -

-
-
-
- ) : ( -
- {configs?.map((config) => { - const member = config.user_id ? memberMap.get(config.user_id) : null; - - return ( -
- - - {/* Header: Icon + Name + Actions */} -
-
-
- {getProviderIcon(config.provider, { className: "size-4" })} -
-
-

- {config.name} -

- {config.description && ( -

- {config.description} -

- )} -
-
- {(canUpdate || canDelete) && ( -
- {canUpdate && ( - - - - - - Edit - - - )} - {canDelete && ( - - - - - - Delete - - - )} -
- )} -
- - {/* Feature badges */} -
- {config.citations_enabled && ( - - Citations - - )} - {!config.use_default_system_instructions && - config.system_instructions && ( - - - Custom - - )} -
- - {/* Footer: Date + Creator */} -
- -
- - {new Date(config.created_at).toLocaleDateString(undefined, { - year: "numeric", - month: "short", - day: "numeric", - })} - - {member && ( - <> - - - - -
- - {member.avatarUrl && ( - - )} - - {getInitials(member.name)} - - - - {member.name} - -
-
- - {member.email || member.name} - -
-
- - )} -
-
-
-
-
- ); - })} -
- )} -
- )} - - {/* Add/Edit Configuration Dialog */} - { - setIsDialogOpen(open); - if (!open) setEditingConfig(null); - }} - config={editingConfig} - isGlobal={false} - searchSpaceId={searchSpaceId} - mode={editingConfig ? "edit" : "create"} - /> - - {/* Delete Confirmation Dialog */} - !open && setConfigToDelete(null)} - > - - - Delete Model - - Are you sure you want to delete{" "} - {configToDelete?.name}? This - action cannot be undone. - - - - Cancel - - {isDeleting ? ( - <> - - Deleting - - ) : ( - "Delete" - )} - - - - -
- ); -} diff --git a/surfsense_web/components/settings/image-model-manager.tsx b/surfsense_web/components/settings/image-model-manager.tsx deleted file mode 100644 index 494f7aae9..000000000 --- a/surfsense_web/components/settings/image-model-manager.tsx +++ /dev/null @@ -1,489 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { AlertCircle, Dot, Info, Pencil, RefreshCw, Trash2 } from "lucide-react"; -import { useMemo, useState } from "react"; -import { deleteImageGenConfigMutationAtom } from "@/atoms/image-gen-config/image-gen-config-mutation.atoms"; -import { - globalImageGenConfigsAtom, - imageGenConfigsAtom, -} from "@/atoms/image-gen-config/image-gen-config-query.atoms"; -import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; -import { ImageConfigDialog } from "@/components/shared/image-config-dialog"; -import { Alert, AlertDescription } from "@/components/ui/alert"; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, -} from "@/components/ui/alert-dialog"; -import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Card, CardContent } from "@/components/ui/card"; -import { Separator } from "@/components/ui/separator"; -import { Skeleton } from "@/components/ui/skeleton"; -import { Spinner } from "@/components/ui/spinner"; -import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; -import type { ImageGenerationConfig } from "@/contracts/types/new-llm-config.types"; -import { useMediaQuery } from "@/hooks/use-media-query"; -import { getProviderIcon } from "@/lib/provider-icons"; -import { cn } from "@/lib/utils"; - -interface ImageModelManagerProps { - searchSpaceId: number; -} - -function getInitials(name: string): string { - const parts = name.trim().split(/\s+/); - if (parts.length >= 2) { - return (parts[0][0] + parts[1][0]).toUpperCase(); - } - return name.slice(0, 2).toUpperCase(); -} - -export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { - const isDesktop = useMediaQuery("(min-width: 768px)"); - - const { - mutateAsync: deleteConfig, - isPending: isDeleting, - error: deleteError, - } = useAtomValue(deleteImageGenConfigMutationAtom); - - const { - data: userConfigs, - isFetching: configsLoading, - error: fetchError, - refetch: refreshConfigs, - } = useAtomValue(imageGenConfigsAtom); - const { data: globalConfigs = [], isFetching: globalLoading } = - useAtomValue(globalImageGenConfigsAtom); - - const { data: members } = useAtomValue(membersAtom); - const memberMap = useMemo(() => { - const map = new Map(); - if (members) { - for (const m of members) { - map.set(m.user_id, { - name: m.user_display_name || m.user_email || "Unknown", - email: m.user_email || undefined, - avatarUrl: m.user_avatar_url || undefined, - }); - } - } - return map; - }, [members]); - - const { data: access } = useAtomValue(myAccessAtom); - const canCreate = - !!access && - (access.is_owner || (access.permissions?.includes("image_generations:create") ?? false)); - const canDelete = - !!access && - (access.is_owner || (access.permissions?.includes("image_generations:delete") ?? false)); - const canUpdate = canCreate; - const isReadOnly = !canCreate && !canDelete; - - const [isDialogOpen, setIsDialogOpen] = useState(false); - const [editingConfig, setEditingConfig] = useState(null); - const [configToDelete, setConfigToDelete] = useState(null); - - const isLoading = configsLoading || globalLoading; - const errors = [deleteError, fetchError].filter(Boolean) as Error[]; - - const openEditDialog = (config: ImageGenerationConfig) => { - setEditingConfig(config); - setIsDialogOpen(true); - }; - - const openNewDialog = () => { - setEditingConfig(null); - setIsDialogOpen(true); - }; - - const handleDelete = async () => { - if (!configToDelete) return; - try { - await deleteConfig({ id: configToDelete.id, name: configToDelete.name }); - setConfigToDelete(null); - } catch { - // Error handled by mutation - } - }; - - return ( -
- {/* Header actions */} -
- - {canCreate && ( - - )} -
- - {/* Errors */} - {errors.map((err) => ( -
- - - {err?.message} - -
- ))} - - {/* Read-only / Limited permissions notice */} - {access && !isLoading && isReadOnly && ( -
- - - -

- You have read-only access to image generation - configurations. Contact a space owner to request additional permissions. -

-
-
-
- )} - {access && !isLoading && !isReadOnly && (!canCreate || !canDelete) && ( -
- - - -

- You can{" "} - {[canCreate && "create and edit", canDelete && "delete"] - .filter(Boolean) - .join(" and ")}{" "} - image model configurations - {!canDelete && ", but cannot delete them"}. -

-
-
-
- )} - - {/* Global info */} - {(isLoading || - globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0) && ( - - - - {isLoading ? ( -
- -
- ) : ( -

- - {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length}{" "} - global image{" "} - {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length === - 1 - ? "model" - : "models"} - {" "} - available from your administrator. {(() => { - const nonAuto = globalConfigs.filter( - (g) => !("is_auto_mode" in g && g.is_auto_mode) - ); - const premium = nonAuto.filter( - (g) => - "billing_tier" in g && - (g as { billing_tier?: string }).billing_tier === "premium" - ).length; - const free = nonAuto.length - premium; - if (premium > 0 && free > 0) { - return `${premium} premium, ${free} free.`; - } - if (premium > 0) { - return `All ${premium} premium — debits your shared credit pool.`; - } - return `All ${free} free.`; - })()} -

- )} -
-
- )} - - {/* Global Image Models — read-only cards with per-model Free/Premium - badges. Mirrors the badge palette used by the chat role selector - (`llm-role-manager.tsx`) so the meaning is consistent across - every model-configuration surface (chat / image / vision). */} - {!isLoading && - globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && ( -
-
- {globalConfigs - .filter((g) => !("is_auto_mode" in g && g.is_auto_mode)) - .map((cfg) => { - const billingTier = - ("billing_tier" in cfg && - typeof (cfg as { billing_tier?: string }).billing_tier === "string" && - (cfg as { billing_tier?: string }).billing_tier) || - "free"; - const isPremium = billingTier === "premium"; - return ( - - -
-
- {getProviderIcon(cfg.provider, { className: "size-4" })} -
-
-

- {cfg.name} -

- {isPremium ? ( - - Premium - - ) : ( - - Free - - )} -
-
- {cfg.description && ( -

- {cfg.description} -

- )} -
- -
- - {cfg.model_name} - -
-
-
-
- ); - })} -
-
- )} - - {/* Loading Skeleton */} - {isLoading && ( -
-
-
- {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( - - - - - - - - ))} -
-
-
- )} - - {/* User Configs */} - {!isLoading && ( -
- {(userConfigs?.length ?? 0) === 0 ? ( - - -

No Image Models Yet

-

- {canCreate - ? "Add your own image generation model (DALL-E 3, GPT Image 1, etc.)" - : "No image models have been added to this space yet. Contact a space owner to add one."} -

-
-
- ) : ( -
- {userConfigs?.map((config) => { - const member = config.user_id ? memberMap.get(config.user_id) : null; - - return ( -
- - - {/* Header: Icon + Name + Actions */} -
-
-
- {getProviderIcon(config.provider, { className: "size-4" })} -
-
-

- {config.name} -

- {config.description && ( -

- {config.description} -

- )} -
-
- {(canUpdate || canDelete) && ( -
- {canUpdate && ( - - - - - - Edit - - - )} - {canDelete && ( - - - - - - Delete - - - )} -
- )} -
- - {/* Footer: Date + Creator */} -
- -
- - {new Date(config.created_at).toLocaleDateString(undefined, { - year: "numeric", - month: "short", - day: "numeric", - })} - - {member && ( - <> - - - - -
- - {member.avatarUrl && ( - - )} - - {getInitials(member.name)} - - - - {member.name} - -
-
- - {member.email || member.name} - -
-
- - )} -
-
-
-
-
- ); - })} -
- )} -
- )} - - {/* Create/Edit Dialog — shared component */} - { - setIsDialogOpen(open); - if (!open) setEditingConfig(null); - }} - config={editingConfig} - isGlobal={false} - searchSpaceId={searchSpaceId} - mode={editingConfig ? "edit" : "create"} - /> - - {/* Delete Confirmation */} - !open && setConfigToDelete(null)} - > - - - Delete Image Model - - Are you sure you want to delete{" "} - {configToDelete?.name}? - - - - Cancel - - Delete - {isDeleting && } - - - - -
- ); -} diff --git a/surfsense_web/components/settings/llm-role-manager.tsx b/surfsense_web/components/settings/llm-role-manager.tsx deleted file mode 100644 index 547675927..000000000 --- a/surfsense_web/components/settings/llm-role-manager.tsx +++ /dev/null @@ -1,443 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { - AlertCircle, - Bot, - CircleCheck, - CircleDashed, - FileText, - ImageIcon, - RefreshCw, - ScanEye, -} from "lucide-react"; -import { useCallback, useEffect, useState } from "react"; -import { toast } from "sonner"; -import { - globalImageGenConfigsAtom, - imageGenConfigsAtom, -} from "@/atoms/image-gen-config/image-gen-config-query.atoms"; -import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; -import { - globalNewLLMConfigsAtom, - llmPreferencesAtom, - newLLMConfigsAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; -import { - globalVisionLLMConfigsAtom, - visionLLMConfigsAtom, -} from "@/atoms/vision-llm-config/vision-llm-config-query.atoms"; -import { Alert, AlertDescription } from "@/components/ui/alert"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Card, CardContent } from "@/components/ui/card"; -import { Label } from "@/components/ui/label"; -import { - Select, - SelectContent, - SelectGroup, - SelectItem, - SelectLabel, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { Skeleton } from "@/components/ui/skeleton"; -import { Spinner } from "@/components/ui/spinner"; -import { cn } from "@/lib/utils"; - -const ROLE_DESCRIPTIONS = { - agent: { - icon: Bot, - title: "Chat model", - description: "Primary model for chat interactions and agent operations", - color: "text-muted-foreground", - bgColor: "bg-muted", - prefKey: "agent_llm_id" as const, - configType: "llm" as const, - }, - image_generation: { - icon: ImageIcon, - title: "Image Generation Model", - description: "Model used for AI image generation (DALL-E, GPT Image, etc.)", - color: "text-muted-foreground", - bgColor: "bg-muted", - prefKey: "image_generation_config_id" as const, - configType: "image" as const, - }, - vision: { - icon: ScanEye, - title: "Vision LLM", - description: "Vision-capable model for screenshot analysis and context extraction", - color: "text-muted-foreground", - bgColor: "bg-muted", - prefKey: "vision_llm_config_id" as const, - configType: "vision" as const, - }, -}; - -interface LLMRoleManagerProps { - searchSpaceId: number; -} - -export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) { - // LLM configs - const { - data: newLLMConfigs = [], - isFetching: configsLoading, - error: configsError, - refetch: refreshConfigs, - } = useAtomValue(newLLMConfigsAtom); - const { - data: globalConfigs = [], - isFetching: globalConfigsLoading, - error: globalConfigsError, - } = useAtomValue(globalNewLLMConfigsAtom); - - // Image gen configs - const { - data: userImageConfigs = [], - isFetching: imageConfigsLoading, - error: imageConfigsError, - } = useAtomValue(imageGenConfigsAtom); - const { - data: globalImageConfigs = [], - isFetching: globalImageConfigsLoading, - error: globalImageConfigsError, - } = useAtomValue(globalImageGenConfigsAtom); - - // Vision LLM configs - const { - data: userVisionConfigs = [], - isFetching: visionConfigsLoading, - error: visionConfigsError, - } = useAtomValue(visionLLMConfigsAtom); - const { - data: globalVisionConfigs = [], - isFetching: globalVisionConfigsLoading, - error: globalVisionConfigsError, - } = useAtomValue(globalVisionLLMConfigsAtom); - - // Preferences - const { - data: preferences = {}, - isFetching: preferencesLoading, - error: preferencesError, - } = useAtomValue(llmPreferencesAtom); - - const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); - - const [assignments, setAssignments] = useState>(() => ({ - agent_llm_id: preferences.agent_llm_id ?? null, - image_generation_config_id: preferences.image_generation_config_id ?? null, - vision_llm_config_id: preferences.vision_llm_config_id ?? null, - })); - - // Sync local state when preferences load/change. Without this, the selects - // stay on their initial (often empty) value while the query is in flight, - // so a saved assignment — including Auto mode (id 0) — never appears. - useEffect(() => { - setAssignments({ - agent_llm_id: preferences.agent_llm_id ?? null, - image_generation_config_id: preferences.image_generation_config_id ?? null, - vision_llm_config_id: preferences.vision_llm_config_id ?? null, - }); - }, [ - preferences.agent_llm_id, - preferences.image_generation_config_id, - preferences.vision_llm_config_id, - ]); - - const [savingRole, setSavingRole] = useState(null); - - const handleRoleAssignment = useCallback( - async (prefKey: string, configId: string) => { - // "unassigned" clears the role (null). Every other option — including - // Auto mode, whose config id is 0 — must be sent as-is. Using a falsy - // check here (e.g. `value || undefined`) would drop id 0 and silently - // fail to persist Auto mode. - const value = configId === "unassigned" ? null : Number(configId); - - setAssignments((prev) => ({ ...prev, [prefKey]: value })); - setSavingRole(prefKey); - - try { - await updatePreferences({ - search_space_id: searchSpaceId, - data: { [prefKey]: value }, - }); - toast.success("Role assignment updated"); - } finally { - setSavingRole(null); - } - }, - [updatePreferences, searchSpaceId] - ); - - // Combine global and custom LLM configs - const allLLMConfigs = [ - ...globalConfigs.map((config) => ({ ...config, is_global: true })), - ...newLLMConfigs.filter((config) => config.id && config.id.toString().trim() !== ""), - ]; - - // Combine global and custom image gen configs - const allImageConfigs = [ - ...globalImageConfigs.map((config) => ({ ...config, is_global: true })), - ...(userImageConfigs ?? []).filter((config) => config.id && config.id.toString().trim() !== ""), - ]; - - // Combine global and custom vision LLM configs - const allVisionConfigs = [ - ...globalVisionConfigs.map((config) => ({ ...config, is_global: true })), - ...(userVisionConfigs ?? []).filter( - (config) => config.id && config.id.toString().trim() !== "" - ), - ]; - - const isLoading = - configsLoading || - preferencesLoading || - globalConfigsLoading || - imageConfigsLoading || - globalImageConfigsLoading || - visionConfigsLoading || - globalVisionConfigsLoading; - const hasError = - configsError || - preferencesError || - globalConfigsError || - imageConfigsError || - globalImageConfigsError || - visionConfigsError || - globalVisionConfigsError; - const hasAnyConfigs = allLLMConfigs.length > 0 || allImageConfigs.length > 0; - - return ( -
- {/* Header actions */} -
- -
- - {/* Error Alert */} - {hasError && ( -
- - - - {(configsError?.message ?? "Failed to load LLM configurations") || - (preferencesError?.message ?? "Failed to load preferences") || - (globalConfigsError?.message ?? "Failed to load global configurations")} - - -
- )} - - {/* Loading Skeleton */} - {isLoading && ( -
- {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( - - - - - - - - ))} -
- )} - - {/* No configs warning */} - {!isLoading && !hasError && !hasAnyConfigs && ( - - - - No configurations found. Please add at least one LLM provider or image model in the - respective settings tabs before assigning roles. - - - )} - - {/* Role Assignment Cards */} - {!isLoading && !hasError && hasAnyConfigs && ( -
- {Object.entries(ROLE_DESCRIPTIONS).map(([key, role]) => { - const IconComponent = role.icon; - const currentAssignment = assignments[role.prefKey as keyof typeof assignments]; - - // Pick the right config lists based on role type - const roleGlobalConfigs = - role.configType === "image" - ? globalImageConfigs - : role.configType === "vision" - ? globalVisionConfigs - : globalConfigs; - const roleUserConfigs = - role.configType === "image" - ? (userImageConfigs ?? []).filter((c) => c.id && c.id.toString().trim() !== "") - : role.configType === "vision" - ? (userVisionConfigs ?? []).filter((c) => c.id && c.id.toString().trim() !== "") - : newLLMConfigs.filter((c) => c.id && c.id.toString().trim() !== ""); - const roleAllConfigs = - role.configType === "image" - ? allImageConfigs - : role.configType === "vision" - ? allVisionConfigs - : allLLMConfigs; - - const assignedConfig = roleAllConfigs.find((config) => config.id === currentAssignment); - const isAssigned = !!assignedConfig; - - return ( -
- - - {/* Role Header */} -
-
-
- -
-
-

{role.title}

-

- {role.description} -

-
-
- {savingRole === role.prefKey ? ( - - ) : isAssigned ? ( - - ) : ( - - )} -
- - {/* Selector */} -
- - -
-
-
-
- ); - })} -
- )} -
- ); -} diff --git a/surfsense_web/components/settings/vision-model-manager.tsx b/surfsense_web/components/settings/vision-model-manager.tsx deleted file mode 100644 index 31578b4f1..000000000 --- a/surfsense_web/components/settings/vision-model-manager.tsx +++ /dev/null @@ -1,486 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { AlertCircle, Dot, Info, Pencil, RefreshCw, Trash2 } from "lucide-react"; -import { useMemo, useState } from "react"; -import { membersAtom, myAccessAtom } from "@/atoms/members/members-query.atoms"; -import { deleteVisionLLMConfigMutationAtom } from "@/atoms/vision-llm-config/vision-llm-config-mutation.atoms"; -import { - globalVisionLLMConfigsAtom, - visionLLMConfigsAtom, -} from "@/atoms/vision-llm-config/vision-llm-config-query.atoms"; -import { VisionConfigDialog } from "@/components/shared/vision-config-dialog"; -import { Alert, AlertDescription } from "@/components/ui/alert"; -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, -} from "@/components/ui/alert-dialog"; -import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Card, CardContent } from "@/components/ui/card"; -import { Separator } from "@/components/ui/separator"; -import { Skeleton } from "@/components/ui/skeleton"; -import { Spinner } from "@/components/ui/spinner"; -import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; -import type { VisionLLMConfig } from "@/contracts/types/new-llm-config.types"; -import { useMediaQuery } from "@/hooks/use-media-query"; -import { getProviderIcon } from "@/lib/provider-icons"; -import { cn } from "@/lib/utils"; - -interface VisionModelManagerProps { - searchSpaceId: number; -} - -function getInitials(name: string): string { - const parts = name.trim().split(/\s+/); - if (parts.length >= 2) { - return (parts[0][0] + parts[1][0]).toUpperCase(); - } - return name.slice(0, 2).toUpperCase(); -} - -export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) { - const isDesktop = useMediaQuery("(min-width: 768px)"); - - const { - mutateAsync: deleteConfig, - isPending: isDeleting, - error: deleteError, - } = useAtomValue(deleteVisionLLMConfigMutationAtom); - - const { - data: userConfigs, - isFetching: configsLoading, - error: fetchError, - refetch: refreshConfigs, - } = useAtomValue(visionLLMConfigsAtom); - const { data: globalConfigs = [], isFetching: globalLoading } = useAtomValue( - globalVisionLLMConfigsAtom - ); - - const { data: members } = useAtomValue(membersAtom); - const memberMap = useMemo(() => { - const map = new Map(); - if (members) { - for (const m of members) { - map.set(m.user_id, { - name: m.user_display_name || m.user_email || "Unknown", - email: m.user_email || undefined, - avatarUrl: m.user_avatar_url || undefined, - }); - } - } - return map; - }, [members]); - - const { data: access } = useAtomValue(myAccessAtom); - const canCreate = useMemo(() => { - if (!access) return false; - if (access.is_owner) return true; - return access.permissions?.includes("vision_configs:create") ?? false; - }, [access]); - const canDelete = useMemo(() => { - if (!access) return false; - if (access.is_owner) return true; - return access.permissions?.includes("vision_configs:delete") ?? false; - }, [access]); - const canUpdate = canCreate; - const isReadOnly = !canCreate && !canDelete; - - const [isDialogOpen, setIsDialogOpen] = useState(false); - const [editingConfig, setEditingConfig] = useState(null); - const [configToDelete, setConfigToDelete] = useState(null); - - const isLoading = configsLoading || globalLoading; - const errors = [deleteError, fetchError].filter(Boolean) as Error[]; - - const openEditDialog = (config: VisionLLMConfig) => { - setEditingConfig(config); - setIsDialogOpen(true); - }; - - const openNewDialog = () => { - setEditingConfig(null); - setIsDialogOpen(true); - }; - - const handleDelete = async () => { - if (!configToDelete) return; - try { - await deleteConfig({ id: configToDelete.id, name: configToDelete.name }); - setConfigToDelete(null); - } catch { - // Error handled by mutation - } - }; - - return ( -
-
- - {canCreate && ( - - )} -
- - {errors.map((err) => ( -
- - - {err?.message} - -
- ))} - - {access && !isLoading && isReadOnly && ( -
- - - -

- You have read-only access to vision model - configurations. Contact a space owner to request additional permissions. -

-
-
-
- )} - {access && !isLoading && !isReadOnly && (!canCreate || !canDelete) && ( -
- - - -

- You can{" "} - {[canCreate && "create and edit", canDelete && "delete"] - .filter(Boolean) - .join(" and ")}{" "} - vision model configurations - {!canDelete && ", but cannot delete them"}. -

-
-
-
- )} - - {(isLoading || - globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0) && ( - - - - {isLoading ? ( -
- -
- ) : ( -

- - {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length}{" "} - global vision{" "} - {globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length === - 1 - ? "model" - : "models"} - {" "} - available from your administrator. {(() => { - const nonAuto = globalConfigs.filter( - (g) => !("is_auto_mode" in g && g.is_auto_mode) - ); - const premium = nonAuto.filter( - (g) => - "billing_tier" in g && - (g as { billing_tier?: string }).billing_tier === "premium" - ).length; - const free = nonAuto.length - premium; - if (premium > 0 && free > 0) { - return `${premium} premium, ${free} free.`; - } - if (premium > 0) { - return `All ${premium} premium — debits your shared credit pool.`; - } - return `All ${free} free.`; - })()} -

- )} -
-
- )} - - {/* Global Vision Models — read-only cards with per-model Free/Premium - badges. Mirrors the badge palette used by the chat role selector - (`llm-role-manager.tsx`) so the meaning is consistent across - every model-configuration surface (chat / image / vision). */} - {!isLoading && - globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && ( -
-
- {globalConfigs - .filter((g) => !("is_auto_mode" in g && g.is_auto_mode)) - .map((cfg) => { - const billingTier = - ("billing_tier" in cfg && - typeof (cfg as { billing_tier?: string }).billing_tier === "string" && - (cfg as { billing_tier?: string }).billing_tier) || - "free"; - const isPremium = billingTier === "premium"; - return ( - - -
-
- {getProviderIcon(cfg.provider, { className: "size-4" })} -
-
-

- {cfg.name} -

- {isPremium ? ( - - Premium - - ) : ( - - Free - - )} -
-
- {cfg.description && ( -

- {cfg.description} -

- )} -
- -
- - {cfg.model_name} - -
-
-
-
- ); - })} -
-
- )} - - {isLoading && ( -
-
-
- {["skeleton-a", "skeleton-b", "skeleton-c"].map((key) => ( - - - - - - - - ))} -
-
-
- )} - - {!isLoading && ( -
- {(userConfigs?.length ?? 0) === 0 ? ( - - -

No Vision Models Yet

-

- {canCreate - ? "Add your own vision-capable model (GPT-4o, Claude, Gemini, etc.)" - : "No vision models have been added to this space yet. Contact a space owner to add one."} -

-
-
- ) : ( -
- {userConfigs?.map((config) => { - const member = config.user_id ? memberMap.get(config.user_id) : null; - - return ( -
- - - {/* Header: Icon + Name + Actions */} -
-
-
- {getProviderIcon(config.provider, { className: "size-4" })} -
-
-

- {config.name} -

- {config.description && ( -

- {config.description} -

- )} -
-
- {(canUpdate || canDelete) && ( -
- {canUpdate && ( - - - - - - Edit - - - )} - {canDelete && ( - - - - - - Delete - - - )} -
- )} -
- - {/* Footer: Date + Creator */} -
- -
- - {new Date(config.created_at).toLocaleDateString(undefined, { - year: "numeric", - month: "short", - day: "numeric", - })} - - {member && ( - <> - - - - -
- - {member.avatarUrl && ( - - )} - - {getInitials(member.name)} - - - - {member.name} - -
-
- - {member.email || member.name} - -
-
- - )} -
-
-
-
-
- ); - })} -
- )} -
- )} - - { - setIsDialogOpen(open); - if (!open) setEditingConfig(null); - }} - config={editingConfig} - isGlobal={false} - searchSpaceId={searchSpaceId} - mode={editingConfig ? "edit" : "create"} - /> - - !open && setConfigToDelete(null)} - > - - - Delete Vision Model - - Are you sure you want to delete{" "} - {configToDelete?.name}? - - - - Cancel - - Delete - {isDeleting && } - - - - -
- ); -} diff --git a/surfsense_web/components/shared/image-config-dialog.tsx b/surfsense_web/components/shared/image-config-dialog.tsx deleted file mode 100644 index 36d16081a..000000000 --- a/surfsense_web/components/shared/image-config-dialog.tsx +++ /dev/null @@ -1,456 +0,0 @@ -"use client"; - -import { useAtomValue } from "jotai"; -import { AlertCircle, Check, ChevronsUpDown } from "lucide-react"; -import { useCallback, useEffect, useMemo, useRef, useState } from "react"; -import { toast } from "sonner"; -import { - createImageGenConfigMutationAtom, - updateImageGenConfigMutationAtom, -} from "@/atoms/image-gen-config/image-gen-config-mutation.atoms"; -import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; -import { Alert, AlertDescription } from "@/components/ui/alert"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { - Command, - CommandEmpty, - CommandGroup, - CommandInput, - CommandItem, - CommandList, -} from "@/components/ui/command"; -import { Dialog, DialogContent, DialogTitle } from "@/components/ui/dialog"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { Separator } from "@/components/ui/separator"; -import { Spinner } from "@/components/ui/spinner"; -import { IMAGE_GEN_MODELS, IMAGE_GEN_PROVIDERS } from "@/contracts/enums/image-gen-providers"; -import type { - GlobalImageGenConfig, - ImageGenerationConfig, - ImageGenProvider, -} from "@/contracts/types/new-llm-config.types"; -import { cn } from "@/lib/utils"; - -interface ImageConfigDialogProps { - open: boolean; - onOpenChange: (open: boolean) => void; - config: ImageGenerationConfig | GlobalImageGenConfig | null; - isGlobal: boolean; - searchSpaceId: number; - mode: "create" | "edit" | "view"; - defaultProvider?: string; -} - -const INITIAL_FORM = { - name: "", - description: "", - provider: "", - model_name: "", - api_key: "", - api_base: "", - api_version: "", -}; - -export function ImageConfigDialog({ - open, - onOpenChange, - config, - isGlobal, - searchSpaceId, - mode, - defaultProvider, -}: ImageConfigDialogProps) { - const [isSubmitting, setIsSubmitting] = useState(false); - const [formData, setFormData] = useState(INITIAL_FORM); - const [modelComboboxOpen, setModelComboboxOpen] = useState(false); - const [scrollPos, setScrollPos] = useState<"top" | "middle" | "bottom">("top"); - const scrollRef = useRef(null); - - useEffect(() => { - if (open) { - if (mode === "edit" && config && !isGlobal) { - setFormData({ - name: config.name || "", - description: config.description || "", - provider: config.provider || "", - model_name: config.model_name || "", - api_key: (config as ImageGenerationConfig).api_key || "", - api_base: config.api_base || "", - api_version: config.api_version || "", - }); - } else if (mode === "create") { - setFormData({ ...INITIAL_FORM, provider: defaultProvider ?? "" }); - } - setScrollPos("top"); - } - }, [open, mode, config, isGlobal, defaultProvider]); - - const { mutateAsync: createConfig } = useAtomValue(createImageGenConfigMutationAtom); - const { mutateAsync: updateConfig } = useAtomValue(updateImageGenConfigMutationAtom); - const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom); - - const handleScroll = useCallback((e: React.UIEvent) => { - const el = e.currentTarget; - const atTop = el.scrollTop <= 2; - const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; - setScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); - }, []); - - const suggestedModels = useMemo(() => { - if (!formData.provider) return []; - return IMAGE_GEN_MODELS.filter((m) => m.provider === formData.provider); - }, [formData.provider]); - - const getTitle = () => { - if (mode === "create") return "Add Image Model"; - if (isGlobal) return "View Global Image Model"; - return "Edit Image Model"; - }; - - const getSubtitle = () => { - if (mode === "create") return "Set up a new image generation provider"; - if (isGlobal) return "Read-only global configuration"; - return "Update your image model settings"; - }; - - const handleSubmit = useCallback(async () => { - setIsSubmitting(true); - try { - if (mode === "create") { - const result = await createConfig({ - name: formData.name, - provider: formData.provider as ImageGenProvider, - model_name: formData.model_name, - api_key: formData.api_key, - api_base: formData.api_base || undefined, - api_version: formData.api_version || undefined, - description: formData.description || undefined, - search_space_id: searchSpaceId, - }); - if (result?.id) { - await updatePreferences({ - search_space_id: searchSpaceId, - data: { image_generation_config_id: result.id }, - }); - } - onOpenChange(false); - } else if (!isGlobal && config) { - await updateConfig({ - id: config.id, - data: { - name: formData.name, - description: formData.description || undefined, - provider: formData.provider as ImageGenProvider, - model_name: formData.model_name, - api_key: formData.api_key, - api_base: formData.api_base || undefined, - api_version: formData.api_version || undefined, - }, - }); - onOpenChange(false); - } - } catch (error) { - console.error("Failed to save image config:", error); - toast.error("Failed to save image model"); - } finally { - setIsSubmitting(false); - } - }, [ - mode, - isGlobal, - config, - formData, - searchSpaceId, - createConfig, - updateConfig, - updatePreferences, - onOpenChange, - ]); - - const handleUseGlobalConfig = useCallback(async () => { - if (!config || !isGlobal) return; - setIsSubmitting(true); - try { - await updatePreferences({ - search_space_id: searchSpaceId, - data: { image_generation_config_id: config.id }, - }); - toast.success(`Now using ${config.name}`); - onOpenChange(false); - } catch (error) { - console.error("Failed to set image model:", error); - toast.error("Failed to set image model"); - } finally { - setIsSubmitting(false); - } - }, [config, isGlobal, searchSpaceId, updatePreferences, onOpenChange]); - - const isFormValid = formData.name && formData.provider && formData.model_name && formData.api_key; - const selectedProvider = IMAGE_GEN_PROVIDERS.find((p) => p.value === formData.provider); - - return ( - - e.preventDefault()} - > - {getTitle()} - - {/* Header */} -
-
-
-

{getTitle()}

- {isGlobal && mode !== "create" && ( - - Global - - )} -
-

{getSubtitle()}

- {config && mode !== "create" && ( -

{config.model_name}

- )} -
-
- - {/* Scrollable content */} -
- {isGlobal && config && ( - <> - - - - Global configurations are read-only. To customize, create a new model. - - -
-
-
-
- Name -
-

{config.name}

-
- {config.description && ( -
-
- Description -
-

{config.description}

-
- )} -
- -
-
-
- Provider -
-

{config.provider}

-
-
-
- Model -
-

{config.model_name}

-
-
-
- - )} - - {(mode === "create" || (mode === "edit" && !isGlobal)) && ( -
-
- - setFormData((p) => ({ ...p, name: e.target.value }))} - /> -
- -
- - setFormData((p) => ({ ...p, description: e.target.value }))} - /> -
- - - -
- - -
- -
- - {suggestedModels.length > 0 ? ( - - - - - - - setFormData((p) => ({ ...p, model_name: val }))} - /> - - - - Type a custom model name - - - - {suggestedModels.map((m) => ( - { - setFormData((p) => ({ ...p, model_name: m.value })); - setModelComboboxOpen(false); - }} - > - - {m.value} - - {m.label} - - - ))} - - - - - - ) : ( - setFormData((p) => ({ ...p, model_name: e.target.value }))} - /> - )} -
- -
- - setFormData((p) => ({ ...p, api_key: e.target.value }))} - /> -
- -
- - setFormData((p) => ({ ...p, api_base: e.target.value }))} - /> -
- - {formData.provider === "AZURE_OPENAI" && ( -
- - setFormData((p) => ({ ...p, api_version: e.target.value }))} - /> -
- )} -
- )} -
- - {/* Fixed footer */} -
- - {mode === "create" || (mode === "edit" && !isGlobal) ? ( - - ) : isGlobal && config ? ( - - ) : null} -
-
-
- ); -} diff --git a/surfsense_web/components/shared/llm-config-form.tsx b/surfsense_web/components/shared/llm-config-form.tsx deleted file mode 100644 index 06de4129b..000000000 --- a/surfsense_web/components/shared/llm-config-form.tsx +++ /dev/null @@ -1,527 +0,0 @@ -"use client"; - -import { zodResolver } from "@hookform/resolvers/zod"; -import { useAtomValue } from "jotai"; -import { Check, ChevronDown, ChevronsUpDown } from "lucide-react"; -import { useEffect, useMemo, useState } from "react"; -import { type Resolver, useForm } from "react-hook-form"; -import { z } from "zod"; -import { - defaultSystemInstructionsAtom, - modelListAtom, -} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible"; -import { - Command, - CommandEmpty, - CommandGroup, - CommandInput, - CommandItem, - CommandList, -} from "@/components/ui/command"; -import { - Form, - FormControl, - FormDescription, - FormField, - FormItem, - FormLabel, - FormMessage, -} from "@/components/ui/form"; -import { Input } from "@/components/ui/input"; -import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { Separator } from "@/components/ui/separator"; -import { Switch } from "@/components/ui/switch"; -import { Textarea } from "@/components/ui/textarea"; -import { LLM_PROVIDERS } from "@/contracts/enums/llm-providers"; -import type { CreateNewLLMConfigRequest } from "@/contracts/types/new-llm-config.types"; -import { cn } from "@/lib/utils"; -import InferenceParamsEditor from "../inference-params-editor"; - -// Form schema with zod -const formSchema = z.object({ - name: z.string().min(1, "Name is required").max(100), - description: z.string().max(500).optional().nullable(), - provider: z.string().min(1, "Provider is required"), - custom_provider: z.string().max(100).optional().nullable(), - model_name: z.string().min(1, "Model name is required").max(100), - api_key: z.string().min(1, "API key is required"), - api_base: z.string().max(500).optional().nullable(), - litellm_params: z.record(z.string(), z.any()).optional().nullable(), - system_instructions: z.string().default(""), - use_default_system_instructions: z.boolean().default(true), - citations_enabled: z.boolean().default(true), - search_space_id: z.number(), -}); - -type FormValues = z.infer; - -export type LLMConfigFormData = CreateNewLLMConfigRequest; - -interface LLMConfigFormProps { - initialData?: Partial; - searchSpaceId: number; - onSubmit: (data: LLMConfigFormData) => Promise; - mode?: "create" | "edit"; - showAdvanced?: boolean; - formId?: string; -} - -export function LLMConfigForm({ - initialData, - searchSpaceId, - onSubmit, - mode = "create", - showAdvanced = true, - formId, -}: LLMConfigFormProps) { - const { data: defaultInstructions, isSuccess: defaultInstructionsLoaded } = useAtomValue( - defaultSystemInstructionsAtom - ); - const { data: dynamicModels } = useAtomValue(modelListAtom); - const [modelComboboxOpen, setModelComboboxOpen] = useState(false); - const [advancedOpen, setAdvancedOpen] = useState(false); - const [systemInstructionsOpen, setSystemInstructionsOpen] = useState(false); - - const form = useForm({ - resolver: zodResolver(formSchema) as Resolver, - defaultValues: { - name: initialData?.name ?? "", - description: initialData?.description ?? "", - provider: initialData?.provider ?? "", - custom_provider: initialData?.custom_provider ?? "", - model_name: initialData?.model_name ?? "", - api_key: initialData?.api_key ?? "", - api_base: initialData?.api_base ?? "", - litellm_params: initialData?.litellm_params ?? {}, - system_instructions: initialData?.system_instructions ?? "", - use_default_system_instructions: initialData?.use_default_system_instructions ?? true, - citations_enabled: initialData?.citations_enabled ?? true, - search_space_id: searchSpaceId, - }, - }); - - // Load default instructions when available (only for new configs) - useEffect(() => { - if ( - mode === "create" && - defaultInstructionsLoaded && - defaultInstructions?.default_system_instructions && - !form.getValues("system_instructions") - ) { - form.setValue("system_instructions", defaultInstructions.default_system_instructions); - } - }, [defaultInstructionsLoaded, defaultInstructions, mode, form]); - - const watchProvider = form.watch("provider"); - const selectedProvider = LLM_PROVIDERS.find((p) => p.value === watchProvider); - const availableModels = useMemo( - () => (dynamicModels ?? []).filter((m) => m.provider === watchProvider), - [dynamicModels, watchProvider] - ); - - const handleProviderChange = (value: string) => { - form.setValue("provider", value); - form.setValue("model_name", ""); - - // Auto-fill API base for certain providers - const provider = LLM_PROVIDERS.find((p) => p.value === value); - if (provider?.apiBase) { - form.setValue("api_base", provider.apiBase); - } - }; - - const handleFormSubmit = async (values: FormValues) => { - await onSubmit(values as LLMConfigFormData); - }; - - return ( -
- - {/* Model Configuration Section */} -
-
- Model Configuration -
- - {/* Name & Description */} -
- ( - - Configuration Name - - - - - - )} - /> - - ( - - - Description - - Optional - - - - - - - - )} - /> -
- - {/* Provider Selection */} - ( - - LLM Provider - - - - )} - /> - - {/* Custom Provider (conditional) */} - {watchProvider === "CUSTOM" && ( - ( - - Custom Provider Name - - - - - - )} - /> - )} - - {/* Model Name with Combobox */} - ( - - Model Name - - - - - - - - - - - -
- {field.value ? `Using: "${field.value}"` : "Type your model name"} -
-
- {availableModels.length > 0 && ( - - {availableModels - .filter( - (model) => - !field.value || - model.value.toLowerCase().includes(field.value.toLowerCase()) || - model.label.toLowerCase().includes(field.value.toLowerCase()) - ) - .slice(0, 50) - .map((model) => ( - { - field.onChange(value); - setModelComboboxOpen(false); - }} - className="py-2" - > - -
-
{model.label}
- {model.contextWindow && ( -
- Context: {model.contextWindow} -
- )} -
-
- ))} -
- )} -
-
-
-
- {selectedProvider?.example && ( - - Example: {selectedProvider.example} - - )} - -
- )} - /> - - {/* API Credentials */} -
- ( - - API Key - - - - {watchProvider === "OLLAMA" && ( - - Ollama doesn't require auth — enter any value - - )} - - - )} - /> - - ( - - - API Base URL - {selectedProvider?.apiBase && ( - - Auto-filled - - )} - - - - - - - )} - /> -
- - {/* Ollama Quick Actions */} - {watchProvider === "OLLAMA" && ( -
- - -
- )} -
- - {/* Advanced Parameters */} - {showAdvanced && ( - <> - - - - - - - ( - - - - - - - )} - /> - - - - )} - - {/* System Instructions & Citations Section */} - - - - - - - {/* System Instructions */} - ( - -
- Instructions for the AI - {defaultInstructions && ( - - )} -
- -