diff --git a/surfsense_backend/alembic/versions/156_add_model_connections.py b/surfsense_backend/alembic/versions/156_add_model_connections.py new file mode 100644 index 000000000..0a11d7f9d --- /dev/null +++ b/surfsense_backend/alembic/versions/156_add_model_connections.py @@ -0,0 +1,178 @@ +"""add model connections + +Revision ID: 156 +Revises: 155 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +revision: str = "156" +down_revision: str | None = "155" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +connection_protocol = postgresql.ENUM( + "OLLAMA", + "OPENAI_COMPATIBLE", + "NATIVE", + name="connectionprotocol", + create_type=False, +) +connection_scope = postgresql.ENUM( + "GLOBAL", + "SEARCH_SPACE", + "USER", + name="connectionscope", + create_type=False, +) +model_source = postgresql.ENUM( + "DISCOVERED", + "MANUAL", + name="modelsource", + create_type=False, +) + + +def upgrade() -> None: + bind = op.get_bind() + connection_protocol.create(bind, checkfirst=True) + connection_scope.create(bind, checkfirst=True) + model_source.create(bind, checkfirst=True) + + op.create_table( + "connections", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("protocol", connection_protocol, nullable=False), + sa.Column("native_provider", sa.String(length=100), nullable=True), + sa.Column("base_url", sa.String(length=500), nullable=True), + sa.Column("api_key", sa.String(), nullable=True), + sa.Column( + "extra", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column("scope", connection_scope, nullable=False), + sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False), + sa.Column("search_space_id", sa.Integer(), nullable=True), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("last_verified_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column("last_status", sa.String(length=50), nullable=True), + sa.Column("last_error", sa.Text(), nullable=True), + sa.CheckConstraint( + "(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR " + "(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR " + "(scope = 'USER' AND user_id IS NOT NULL)", + name="ck_connections_scope_owner", + ), + sa.ForeignKeyConstraint( + ["search_space_id"], ["searchspaces.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_connections_protocol"), "connections", ["protocol"], unique=False) + op.create_index( + op.f("ix_connections_native_provider"), + "connections", + ["native_provider"], + unique=False, + ) + op.create_index(op.f("ix_connections_scope"), "connections", ["scope"], unique=False) + + op.create_table( + "models", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("connection_id", sa.Integer(), nullable=False), + sa.Column("model_id", sa.String(length=255), nullable=False), + sa.Column("display_name", sa.String(length=255), nullable=True), + sa.Column( + "source", + model_source, + server_default="DISCOVERED", + nullable=False, + ), + sa.Column( + "capabilities", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column( + "capabilities_declared", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column( + "capabilities_verified", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column( + "capabilities_override", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.Column("embedding_dimension", sa.Integer(), nullable=True), + sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False), + sa.Column("billing_tier", sa.String(length=50), nullable=True), + sa.Column( + "catalog", + postgresql.JSONB(astext_type=sa.Text()), + server_default=sa.text("'{}'::jsonb"), + nullable=False, + ), + sa.ForeignKeyConstraint(["connection_id"], ["connections.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint( + "connection_id", "model_id", name="uq_models_connection_model_id" + ), + ) + op.create_index(op.f("ix_models_connection_id"), "models", ["connection_id"], unique=False) + op.create_index("ix_models_model_id", "models", ["model_id"], unique=False) + op.create_index(op.f("ix_models_billing_tier"), "models", ["billing_tier"], unique=False) + + op.add_column( + "searchspaces", + sa.Column("chat_model_id", sa.Integer(), nullable=True), + ) + op.add_column( + "searchspaces", + sa.Column("image_gen_model_id", sa.Integer(), nullable=True), + ) + op.add_column( + "searchspaces", + sa.Column("vision_model_id", sa.Integer(), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("searchspaces", "vision_model_id") + op.drop_column("searchspaces", "image_gen_model_id") + op.drop_column("searchspaces", "chat_model_id") + + op.drop_index(op.f("ix_models_billing_tier"), table_name="models") + op.drop_index("ix_models_model_id", table_name="models") + op.drop_index(op.f("ix_models_connection_id"), table_name="models") + op.drop_table("models") + + op.drop_index(op.f("ix_connections_scope"), table_name="connections") + op.drop_index(op.f("ix_connections_native_provider"), table_name="connections") + op.drop_index(op.f("ix_connections_protocol"), table_name="connections") + op.drop_table("connections") + + bind = op.get_bind() + model_source.drop(bind, checkfirst=True) + connection_scope.drop(bind, checkfirst=True) + connection_protocol.drop(bind, checkfirst=True) diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 6117caecb..9756cb32f 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -280,6 +280,23 @@ class VisionProvider(StrEnum): CUSTOM = "CUSTOM" +class ConnectionProtocol(StrEnum): + OLLAMA = "OLLAMA" + OPENAI_COMPATIBLE = "OPENAI_COMPATIBLE" + NATIVE = "NATIVE" + + +class ConnectionScope(StrEnum): + GLOBAL = "GLOBAL" + SEARCH_SPACE = "SEARCH_SPACE" + USER = "USER" + + +class ModelSource(StrEnum): + DISCOVERED = "DISCOVERED" + MANUAL = "MANUAL" + + class LogLevel(StrEnum): DEBUG = "DEBUG" INFO = "INFO" @@ -1642,6 +1659,79 @@ class Report(BaseModel, TimestampMixin): thread = relationship("NewChatThread") +class Connection(BaseModel, TimestampMixin): + __tablename__ = "connections" + + protocol = Column(SQLAlchemyEnum(ConnectionProtocol), nullable=False, index=True) + native_provider = Column(String(100), nullable=True, index=True) + base_url = Column(String(500), nullable=True) + api_key = Column(String, nullable=True) + extra = Column(JSONB, nullable=False, default=dict, server_default="{}") + scope = Column(SQLAlchemyEnum(ConnectionScope), nullable=False, index=True) + enabled = Column(Boolean, nullable=False, default=True, server_default="true") + + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=True + ) + user_id = Column( + UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=True + ) + + last_verified_at = Column(TIMESTAMP(timezone=True), nullable=True) + last_status = Column(String(50), nullable=True) + last_error = Column(Text, nullable=True) + + search_space = relationship("SearchSpace", back_populates="connections") + user = relationship("User", back_populates="connections") + models = relationship( + "Model", + back_populates="connection", + order_by="Model.id", + cascade="all, delete-orphan", + passive_deletes=True, + ) + + __table_args__ = ( + CheckConstraint( + "(scope = 'GLOBAL' AND search_space_id IS NULL AND user_id IS NULL) OR " + "(scope = 'SEARCH_SPACE' AND search_space_id IS NOT NULL AND user_id IS NOT NULL) OR " + "(scope = 'USER' AND user_id IS NOT NULL)", + name="ck_connections_scope_owner", + ), + ) + + +class Model(BaseModel, TimestampMixin): + __tablename__ = "models" + + connection_id = Column( + Integer, ForeignKey("connections.id", ondelete="CASCADE"), nullable=False, index=True + ) + model_id = Column(String(255), nullable=False) + display_name = Column(String(255), nullable=True) + source = Column( + SQLAlchemyEnum(ModelSource), + nullable=False, + default=ModelSource.DISCOVERED, + server_default=ModelSource.DISCOVERED.value, + ) + capabilities = Column(JSONB, nullable=False, default=dict, server_default="{}") + capabilities_declared = Column(JSONB, nullable=False, default=dict, server_default="{}") + capabilities_verified = Column(JSONB, nullable=False, default=dict, server_default="{}") + capabilities_override = Column(JSONB, nullable=False, default=dict, server_default="{}") + embedding_dimension = Column(Integer, nullable=True) + enabled = Column(Boolean, nullable=False, default=True, server_default="true") + billing_tier = Column(String(50), nullable=True, index=True) + catalog = Column(JSONB, nullable=False, default=dict, server_default="{}") + + connection = relationship("Connection", back_populates="models") + + __table_args__ = ( + UniqueConstraint("connection_id", "model_id", name="uq_models_connection_model_id"), + Index("ix_models_model_id", "model_id"), + ) + + class ImageGenerationConfig(BaseModel, TimestampMixin): """ Dedicated configuration table for image generation models. @@ -1794,7 +1884,7 @@ class SearchSpace(BaseModel, TimestampMixin): # - Positive IDs: Custom configs from DB (NewLLMConfig table) agent_llm_id = Column( Integer, nullable=True, default=0 - ) # For agent/chat operations, defaults to Auto mode + ) # For chat operations, defaults to Auto mode image_generation_config_id = Column( Integer, nullable=True, default=0 ) # For image generation, defaults to Auto mode @@ -1802,6 +1892,22 @@ class SearchSpace(BaseModel, TimestampMixin): Integer, nullable=True, default=0 ) # For vision/screenshot analysis, defaults to Auto mode + # New connection/model role bindings. These supersede the legacy config + # columns above without removing them in this PR. + # Note: ID values preserve the existing convention: + # - 0: Auto mode + # - Negative IDs: Global virtual models from global_llm_config.yaml + # - Positive IDs: User/search-space models from the models table + chat_model_id = Column( + Integer, nullable=True, default=0 + ) # For agent/chat operations, defaults to Auto mode + image_gen_model_id = Column( + Integer, nullable=True, default=0 + ) # For image generation, defaults to Auto mode when eligible + vision_model_id = Column( + Integer, nullable=True, default=0 + ) # For vision/screenshot analysis, defaults to Auto mode + ai_file_sort_enabled = Column( Boolean, nullable=False, default=False, server_default="false" ) @@ -1889,6 +1995,13 @@ class SearchSpace(BaseModel, TimestampMixin): order_by="VisionLLMConfig.id", cascade="all, delete-orphan", ) + connections = relationship( + "Connection", + back_populates="search_space", + order_by="Connection.id", + cascade="all, delete-orphan", + passive_deletes=True, + ) automations = relationship( "Automation", @@ -2429,6 +2542,11 @@ if config.AUTH_TYPE == "GOOGLE": back_populates="user", passive_deletes=True, ) + connections = relationship( + "Connection", + back_populates="user", + passive_deletes=True, + ) # Automations created by this user automations = relationship( @@ -2568,6 +2686,11 @@ else: back_populates="user", passive_deletes=True, ) + connections = relationship( + "Connection", + back_populates="user", + passive_deletes=True, + ) # Automations created by this user automations = relationship( diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index 5cc029884..244208550 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -42,6 +42,7 @@ from .linear_add_connector_route import router as linear_add_connector_router from .logs_routes import router as logs_router from .luma_add_connector_route import router as luma_add_connector_router from .mcp_oauth_route import router as mcp_oauth_router +from .model_connections_routes import router as model_connections_router from .memory_routes import router as memory_router from .model_list_routes import router as model_list_router from .new_chat_routes import router as new_chat_router @@ -117,6 +118,7 @@ router.include_router(confluence_add_connector_router) router.include_router(clickup_add_connector_router) router.include_router(dropbox_add_connector_router) router.include_router(new_llm_config_router) # LLM configs with prompt configuration +router.include_router(model_connections_router) # Connection-centric model catalog router.include_router(model_list_router) # Dynamic model catalogue from OpenRouter router.include_router(logs_router) router.include_router(circleback_webhook_router) # Circleback meeting webhooks diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py new file mode 100644 index 000000000..69910183d --- /dev/null +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -0,0 +1,346 @@ +import logging + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from app.config import config +from app.db import ( + Connection, + ConnectionScope, + Model, + Permission, + SearchSpace, + User, + get_async_session, +) +from app.schemas import ( + ConnectionCreate, + ConnectionRead, + ConnectionUpdate, + ModelRead, + ModelRolesRead, + ModelRolesUpdate, + ModelUpdate, + VerifyConnectionResponse, +) +from app.services.model_connection_service import ( + discover_models, + persist_verification, + test_model, +) +from app.users import current_active_user +from app.utils.rbac import check_permission + +router = APIRouter() +logger = logging.getLogger(__name__) + + +def _model_read(model: Model | dict) -> ModelRead: + return ModelRead.model_validate(model) + + +def _connection_read(conn: Connection | dict, models: list[Model | dict] | None = None) -> ConnectionRead: + if isinstance(conn, dict): + payload = { + **conn, + "has_api_key": bool(conn.get("api_key")), + "api_key": None, + "models": [_model_read(model) for model in (models or [])], + } + payload.pop("api_key", None) + return ConnectionRead.model_validate(payload) + + return ConnectionRead( + id=conn.id, + protocol=conn.protocol, + native_provider=conn.native_provider, + base_url=conn.base_url, + extra=conn.extra or {}, + scope=conn.scope, + search_space_id=conn.search_space_id, + user_id=conn.user_id, + enabled=conn.enabled, + has_api_key=bool(conn.api_key), + last_verified_at=conn.last_verified_at, + last_status=conn.last_status, + last_error=conn.last_error, + models=[_model_read(model) for model in (models or conn.models or [])], + created_at=conn.created_at, + ) + + +async def _get_search_space(session: AsyncSession, search_space_id: int) -> SearchSpace: + result = await session.execute(select(SearchSpace).where(SearchSpace.id == search_space_id)) + search_space = result.scalars().first() + if not search_space: + raise HTTPException(status_code=404, detail="Search space not found") + return search_space + + +async def _load_connection(session: AsyncSession, connection_id: int) -> Connection: + result = await session.execute( + select(Connection) + .options(selectinload(Connection.models)) + .where(Connection.id == connection_id) + ) + conn = result.scalars().first() + if not conn: + raise HTTPException(status_code=404, detail="Connection not found") + return conn + + +async def _assert_connection_access( + session: AsyncSession, + user: User, + conn: Connection, + permission: str = Permission.LLM_CONFIGS_CREATE.value, +) -> None: + if conn.search_space_id: + await check_permission( + session, + user, + conn.search_space_id, + permission, + "You don't have permission to manage model connections in this search space", + ) + return + if conn.user_id != user.id: + raise HTTPException(status_code=403, detail="Connection does not belong to user") + + +@router.get("/global-model-connections", response_model=list[ConnectionRead]) +async def list_global_connections(user: User = Depends(current_active_user)): + del user + models_by_connection: dict[int, list[dict]] = {} + for model in config.GLOBAL_MODELS: + models_by_connection.setdefault(model["connection_id"], []).append(model) + return [ + _connection_read(conn, models_by_connection.get(conn["id"], [])) + for conn in config.GLOBAL_CONNECTIONS + ] + + +@router.get("/model-connections", response_model=list[ConnectionRead]) +async def list_connections( + search_space_id: int | None = None, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + stmt = select(Connection).options(selectinload(Connection.models)) + if search_space_id is not None: + await check_permission( + session, + user, + search_space_id, + Permission.LLM_CONFIGS_CREATE.value, + "You don't have permission to view model connections in this search space", + ) + stmt = stmt.where(Connection.search_space_id == search_space_id) + else: + stmt = stmt.where(Connection.user_id == user.id) + result = await session.execute(stmt.order_by(Connection.id)) + return [_connection_read(conn) for conn in result.scalars().all()] + + +@router.post("/model-connections", response_model=ConnectionRead) +async def create_connection( + data: ConnectionCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + if data.scope == ConnectionScope.GLOBAL: + raise HTTPException(status_code=400, detail="GLOBAL connections are YAML-only") + if data.scope == ConnectionScope.SEARCH_SPACE: + if data.search_space_id is None: + raise HTTPException(status_code=400, detail="search_space_id is required") + await check_permission( + session, + user, + data.search_space_id, + Permission.LLM_CONFIGS_CREATE.value, + "You don't have permission to create model connections in this search space", + ) + conn = Connection( + **data.model_dump(exclude={"search_space_id"}), + search_space_id=data.search_space_id if data.scope == ConnectionScope.SEARCH_SPACE else None, + user_id=user.id, + ) + session.add(conn) + await session.commit() + await session.refresh(conn) + return _connection_read(conn) + + +@router.put("/model-connections/{connection_id}", response_model=ConnectionRead) +async def update_connection( + connection_id: int, + data: ConnectionUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + conn = await _load_connection(session, connection_id) + await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value) + for key, value in data.model_dump(exclude_unset=True).items(): + setattr(conn, key, value) + await session.commit() + await session.refresh(conn) + return _connection_read(conn) + + +@router.delete("/model-connections/{connection_id}") +async def delete_connection( + connection_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + conn = await _load_connection(session, connection_id) + await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_DELETE.value) + await session.delete(conn) + await session.commit() + return {"status": "deleted"} + + +@router.post("/model-connections/{connection_id}/verify", response_model=VerifyConnectionResponse) +async def verify_model_connection( + connection_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + conn = await _load_connection(session, connection_id) + await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_CREATE.value) + result = await persist_verification(conn) + await session.commit() + return VerifyConnectionResponse(status=result.status, ok=result.ok, message=result.message) + + +@router.post("/model-connections/{connection_id}/discover", response_model=list[ModelRead]) +async def discover_connection_models( + connection_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + conn = await _load_connection(session, connection_id) + await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_CREATE.value) + discovered = await discover_models(conn) + by_model_id = {model.model_id: model for model in conn.models} + for item in discovered: + db_model = by_model_id.get(item["model_id"]) + if db_model is None: + db_model = Model( + connection_id=conn.id, + model_id=item["model_id"], + display_name=item.get("display_name"), + source=item["source"], + capabilities=item["capabilities"], + capabilities_declared=item["capabilities"], + capabilities_verified={}, + capabilities_override={}, + enabled=False, + catalog=item.get("metadata") or {}, + ) + session.add(db_model) + else: + db_model.display_name = item.get("display_name") or db_model.display_name + db_model.capabilities_declared = item["capabilities"] + db_model.capabilities = { + **item["capabilities"], + **(db_model.capabilities_override or {}), + } + db_model.catalog = item.get("metadata") or db_model.catalog + await session.commit() + await session.refresh(conn) + return [_model_read(model) for model in conn.models] + + +@router.put("/models/{model_id}", response_model=ModelRead) +async def update_model( + model_id: int, + data: ModelUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + result = await session.execute( + select(Model).options(selectinload(Model.connection)).where(Model.id == model_id) + ) + model = result.scalars().first() + if not model: + raise HTTPException(status_code=404, detail="Model not found") + await _assert_connection_access(session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value) + update = data.model_dump(exclude_unset=True) + for key, value in update.items(): + setattr(model, key, value) + if "capabilities_override" in update: + model.capabilities = { + **(model.capabilities_declared or {}), + **(model.capabilities_override or {}), + } + await session.commit() + await session.refresh(model) + return _model_read(model) + + +@router.post("/models/{model_id}/test", response_model=VerifyConnectionResponse) +async def test_connection_model( + model_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + result = await session.execute( + select(Model).options(selectinload(Model.connection)).where(Model.id == model_id) + ) + model = result.scalars().first() + if not model: + raise HTTPException(status_code=404, detail="Model not found") + await _assert_connection_access(session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value) + result = await test_model(model.connection, model) + await session.commit() + return VerifyConnectionResponse(status=result.status, ok=result.ok, message=result.message) + + +@router.get("/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead) +async def get_model_roles( + search_space_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + await check_permission( + session, + user, + search_space_id, + Permission.LLM_CONFIGS_CREATE.value, + "You don't have permission to view model roles in this search space", + ) + search_space = await _get_search_space(session, search_space_id) + return ModelRolesRead( + chat_model_id=search_space.chat_model_id, + vision_model_id=search_space.vision_model_id, + image_gen_model_id=search_space.image_gen_model_id, + ) + + +@router.put("/search-spaces/{search_space_id}/model-roles", response_model=ModelRolesRead) +async def update_model_roles( + search_space_id: int, + data: ModelRolesUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + await check_permission( + session, + user, + search_space_id, + Permission.LLM_CONFIGS_UPDATE.value, + "You don't have permission to update model roles in this search space", + ) + search_space = await _get_search_space(session, search_space_id) + for key, value in data.model_dump(exclude_unset=True).items(): + setattr(search_space, key, value) + await session.commit() + await session.refresh(search_space) + return ModelRolesRead( + chat_model_id=search_space.chat_model_id, + vision_model_id=search_space.vision_model_id, + image_gen_model_id=search_space.image_gen_model_id, + ) diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index fdf34672b..c14671c99 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -44,6 +44,16 @@ from .image_generation import ( ImageGenerationRead, ) from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate +from .model_connections import ( + ConnectionCreate, + ConnectionRead, + ConnectionUpdate, + ModelRead, + ModelRolesRead, + ModelRolesUpdate, + ModelUpdate, + VerifyConnectionResponse, +) from .new_chat import ( ChatMessage, NewChatMessageAppend, diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py new file mode 100644 index 000000000..731064375 --- /dev/null +++ b/surfsense_backend/app/schemas/model_connections.py @@ -0,0 +1,89 @@ +import uuid +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + +from app.db import ConnectionProtocol, ConnectionScope, ModelSource + + +class ModelRead(BaseModel): + id: int + connection_id: int + model_id: str + display_name: str | None = None + source: ModelSource | str + capabilities: dict[str, Any] + capabilities_declared: dict[str, Any] = Field(default_factory=dict) + capabilities_verified: dict[str, Any] = Field(default_factory=dict) + capabilities_override: dict[str, Any] = Field(default_factory=dict) + embedding_dimension: int | None = None + enabled: bool + billing_tier: str | None = None + catalog: dict[str, Any] = Field(default_factory=dict) + created_at: datetime | None = None + + model_config = ConfigDict(from_attributes=True) + + +class ConnectionRead(BaseModel): + id: int + protocol: ConnectionProtocol | str + native_provider: str | None = None + base_url: str | None = None + extra: dict[str, Any] = Field(default_factory=dict) + scope: ConnectionScope | str + search_space_id: int | None = None + user_id: uuid.UUID | None = None + enabled: bool + has_api_key: bool + last_verified_at: datetime | None = None + last_status: str | None = None + last_error: str | None = None + models: list[ModelRead] = Field(default_factory=list) + created_at: datetime | None = None + + model_config = ConfigDict(from_attributes=True) + + +class ConnectionCreate(BaseModel): + protocol: ConnectionProtocol + native_provider: str | None = None + base_url: str | None = Field(None, max_length=500) + api_key: str | None = None + extra: dict[str, Any] = Field(default_factory=dict) + scope: ConnectionScope = ConnectionScope.SEARCH_SPACE + search_space_id: int | None = None + enabled: bool = True + + +class ConnectionUpdate(BaseModel): + native_provider: str | None = None + base_url: str | None = Field(None, max_length=500) + api_key: str | None = None + extra: dict[str, Any] | None = None + enabled: bool | None = None + + +class ModelUpdate(BaseModel): + display_name: str | None = Field(None, max_length=255) + enabled: bool | None = None + capabilities_override: dict[str, Any] | None = None + + +class VerifyConnectionResponse(BaseModel): + status: str + ok: bool + message: str = "" + + +class ModelRolesRead(BaseModel): + chat_model_id: int | None = 0 + vision_model_id: int | None = 0 + image_gen_model_id: int | None = 0 + + +class ModelRolesUpdate(BaseModel): + chat_model_id: int | None = None + vision_model_id: int | None = None + image_gen_model_id: int | None = None diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py new file mode 100644 index 000000000..81090acaf --- /dev/null +++ b/surfsense_backend/app/services/model_connection_service.py @@ -0,0 +1,209 @@ +"""Connection verification, model discovery, and capability probing.""" + +from __future__ import annotations + +import contextlib +import logging +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Any + +import httpx +import litellm + +from app.db import Connection, ConnectionProtocol, Model, ModelSource +from app.services.model_resolver import ensure_v1, to_litellm + +logger = logging.getLogger(__name__) + +VERIFY_TIMEOUT_SECONDS = 8.0 +DISCOVERY_TIMEOUT_SECONDS = 15.0 +TEST_TIMEOUT_SECONDS = 30.0 + + +@dataclass(frozen=True) +class VerifyResult: + status: str + ok: bool + message: str = "" + + +def _auth_headers(conn: Connection) -> dict[str, str]: + if not conn.api_key: + return {} + return {"Authorization": f"Bearer {conn.api_key}"} + + +def _docker_hint(url: str | None, exc_or_status: Any) -> str: + raw = str(exc_or_status) + if not url: + return raw + if "localhost" in url or "127.0.0.1" in url: + return ( + f"{raw}. The backend is running inside Docker; localhost means the " + "backend container. Use host.docker.internal and make sure the model " + "server listens on 0.0.0.0." + ) + if "host.docker.internal" in url and ("refused" in raw.lower() or "connect" in raw.lower()): + return ( + f"{raw}. The host is reachable only if your local model server is " + "listening on 0.0.0.0. On Linux Docker, add " + "`host.docker.internal:host-gateway` to extra_hosts." + ) + return raw + + +async def verify_connection(conn: Connection) -> VerifyResult: + if not conn.base_url and conn.protocol in ( + ConnectionProtocol.OLLAMA, + ConnectionProtocol.OPENAI_COMPATIBLE, + ): + return VerifyResult("UNREACHABLE", False, "Base URL is required.") + + if conn.protocol == ConnectionProtocol.OLLAMA: + url = f"{conn.base_url.rstrip('/')}/api/version" + elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: + url = f"{ensure_v1(conn.base_url)}/models" + else: + # Native providers do not share one cheap health endpoint. The model + # probe exercises the real path and is the authoritative check. + return VerifyResult("OK", True, "Native provider configuration accepted.") + + try: + async with httpx.AsyncClient(timeout=VERIFY_TIMEOUT_SECONDS) as client: + response = await client.get(url, headers=_auth_headers(conn)) + if response.status_code in (401, 403): + return VerifyResult("AUTH_FAILED", False, "Authentication failed.") + if response.status_code == 404: + if conn.protocol == ConnectionProtocol.OLLAMA and url.endswith("/v1/models"): + message = "Ollama native API should not use /v1." + elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: + message = "OpenAI-compatible servers should expose /v1/models." + else: + message = "Endpoint returned 404." + return VerifyResult("NOT_FOUND", False, message) + response.raise_for_status() + return VerifyResult("OK", True, "Connection verified.") + except httpx.ConnectError as exc: + return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc)) + except httpx.TimeoutException as exc: + return VerifyResult("UNREACHABLE", False, f"Connection timed out: {exc}") + except httpx.HTTPError as exc: + return VerifyResult("UNREACHABLE", False, _docker_hint(conn.base_url, exc)) + + +async def persist_verification(conn: Connection) -> VerifyResult: + result = await verify_connection(conn) + conn.last_verified_at = datetime.now(UTC) + conn.last_status = result.status + conn.last_error = "" if result.ok else result.message + return result + + +def _litellm_capabilities(model_string: str, model_id: str) -> dict[str, bool]: + capabilities = { + "chat": True, + "vision": False, + "tools": False, + "image_gen": False, + "embedding": False, + } + with contextlib.suppress(Exception): + capabilities["vision"] = bool(litellm.supports_vision(model=model_string)) + with contextlib.suppress(Exception): + capabilities["tools"] = bool(litellm.supports_function_calling(model=model_string)) + try: + info = litellm.model_cost.get(model_string) or litellm.model_cost.get(model_id) or {} + mode = str(info.get("mode") or "") + capabilities["embedding"] = mode == "embedding" + capabilities["image_gen"] = mode in {"image_generation", "image_generation_model"} + except Exception: + pass + return capabilities + + +def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, bool]: + metadata = metadata or {} + if conn.protocol == ConnectionProtocol.OLLAMA: + caps = metadata.get("capabilities") or [] + capabilities = { + "chat": True, + "vision": "vision" in caps, + "tools": False, + "image_gen": False, + "embedding": "embedding" in caps, + } + return capabilities + + model_string, _ = to_litellm(conn, model_id) + return _litellm_capabilities(model_string, model_id) + + +async def discover_models(conn: Connection) -> list[dict[str, Any]]: + if conn.protocol == ConnectionProtocol.OLLAMA: + url = f"{conn.base_url.rstrip('/')}/api/tags" + async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: + response = await client.get(url, headers=_auth_headers(conn)) + response.raise_for_status() + models = response.json().get("models", []) + return [ + { + "model_id": item.get("model") or item.get("name"), + "display_name": item.get("name") or item.get("model"), + "source": ModelSource.DISCOVERED, + "capabilities": derive_capabilities(conn, item.get("model") or item.get("name"), item), + "metadata": item, + } + for item in models + if item.get("model") or item.get("name") + ] + + if conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: + url = f"{ensure_v1(conn.base_url)}/models" + async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: + response = await client.get(url, headers=_auth_headers(conn)) + response.raise_for_status() + models = response.json().get("data", []) + return [ + { + "model_id": item.get("id"), + "display_name": item.get("name") or item.get("id"), + "source": ModelSource.DISCOVERED, + "capabilities": derive_capabilities(conn, item.get("id"), item), + "metadata": item, + } + for item in models + if item.get("id") + ] + + # Native providers rely on curated/global catalog entries or manual rows. + return [] + + +async def test_model(conn: Connection, model: Model) -> VerifyResult: + model_string, kwargs = to_litellm(conn, model.model_id) + try: + await litellm.acompletion( + model=model_string, + messages=[{"role": "user", "content": "Hello"}], + timeout=TEST_TIMEOUT_SECONDS, + **kwargs, + ) + except Exception as exc: + return VerifyResult("UNREACHABLE", False, str(exc)) + + model.capabilities_verified = { + **(model.capabilities_verified or {}), + "chat": True, + } + return VerifyResult("OK", True, "Model test succeeded.") + + +__all__ = [ + "VerifyResult", + "derive_capabilities", + "discover_models", + "persist_verification", + "test_model", + "verify_connection", +] diff --git a/surfsense_backend/tests/unit/services/test_model_connections.py b/surfsense_backend/tests/unit/services/test_model_connections.py new file mode 100644 index 000000000..98042501b --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_model_connections.py @@ -0,0 +1,75 @@ +from app.services.global_model_catalog import materialize_global_model_catalog +from app.services.model_resolver import ensure_v1, to_litellm + + +def test_openai_compatible_resolver_normalizes_v1() -> None: + model, kwargs = to_litellm( + { + "protocol": "OPENAI_COMPATIBLE", + "base_url": "http://host.docker.internal:1234", + "api_key": "local-key", + "extra": {}, + }, + "qwen/qwen3", + ) + + assert model == "openai/qwen/qwen3" + assert kwargs["api_base"] == "http://host.docker.internal:1234/v1" + assert kwargs["api_key"] == "local-key" + assert ensure_v1("http://example.com/v1") == "http://example.com/v1" + + +def test_ollama_resolver_uses_native_api_base() -> None: + model, kwargs = to_litellm( + { + "protocol": "OLLAMA", + "base_url": "http://host.docker.internal:11434", + "api_key": None, + "extra": {}, + }, + "llama3.2", + ) + + assert model == "ollama_chat/llama3.2" + assert kwargs["api_base"] == "http://host.docker.internal:11434" + + +def test_global_materialization_preserves_tier_and_keeps_key_server_side() -> None: + connections, models = materialize_global_model_catalog( + chat_configs=[ + { + "id": -101, + "name": "OpenRouter Free", + "provider": "OPENROUTER", + "model_name": "meta-llama/llama-3.1-8b-instruct:free", + "api_key": "sk-global-secret", + "billing_tier": "free", + "anonymous_enabled": True, + "seo_enabled": True, + "rpm": 10, + "tpm": 1000, + }, + { + "id": -102, + "name": "OpenRouter Premium", + "provider": "OPENROUTER", + "model_name": "anthropic/claude-sonnet-4", + "api_key": "sk-global-secret", + "billing_tier": "premium", + }, + ], + vision_configs=[], + image_configs=[], + ) + + assert len(connections) == 1 + assert connections[0]["api_key"] == "sk-global-secret" + assert {model["billing_tier"] for model in models} == {"free", "premium"} + assert models[0]["catalog"]["anonymous_enabled"] is True + assert models[0]["catalog"]["rpm"] == 10 + + public_connections = [ + {key: value for key, value in connection.items() if key != "api_key"} + for connection in connections + ] + assert "sk-" not in repr(public_connections)