From 780e24213240310d52071c89c528ad1643d86071 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Thu, 11 Jun 2026 00:11:53 +0530 Subject: [PATCH] feat(model-connections): implement manual model addition and enhance model discovery --- .../app/routes/model_connections_routes.py | 38 +++ surfsense_backend/app/schemas/__init__.py | 1 + .../app/schemas/model_connections.py | 11 + .../app/services/model_connection_service.py | 28 +- .../model-connections-mutation.atoms.ts | 28 +- .../settings/model-connections-settings.tsx | 312 ++++++++++++------ .../types/model-connections.types.ts | 6 + .../lib/apis/model-connections-api.service.ts | 12 + 8 files changed, 335 insertions(+), 101 deletions(-) diff --git a/surfsense_backend/app/routes/model_connections_routes.py b/surfsense_backend/app/routes/model_connections_routes.py index 69910183d..6d19a5ed1 100644 --- a/surfsense_backend/app/routes/model_connections_routes.py +++ b/surfsense_backend/app/routes/model_connections_routes.py @@ -10,6 +10,7 @@ from app.db import ( Connection, ConnectionScope, Model, + ModelSource, Permission, SearchSpace, User, @@ -19,6 +20,7 @@ from app.schemas import ( ConnectionCreate, ConnectionRead, ConnectionUpdate, + ModelCreate, ModelRead, ModelRolesRead, ModelRolesUpdate, @@ -26,6 +28,7 @@ from app.schemas import ( VerifyConnectionResponse, ) from app.services.model_connection_service import ( + derive_capabilities, discover_models, persist_verification, test_model, @@ -254,6 +257,41 @@ async def discover_connection_models( return [_model_read(model) for model in conn.models] +@router.post("/model-connections/{connection_id}/models", response_model=ModelRead) +async def add_manual_model( + connection_id: int, + data: ModelCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + conn = await _load_connection(session, connection_id) + await _assert_connection_access(session, user, conn, Permission.LLM_CONFIGS_UPDATE.value) + + model_id = data.model_id.strip() + if not model_id: + raise HTTPException(status_code=400, detail="model_id is required") + if any(existing.model_id == model_id for existing in conn.models): + raise HTTPException(status_code=400, detail="Model already exists on this connection") + + capabilities = derive_capabilities(conn, model_id) + model = Model( + connection_id=conn.id, + model_id=model_id, + display_name=data.display_name or None, + source=ModelSource.MANUAL, + capabilities=capabilities, + capabilities_declared=capabilities, + capabilities_verified={}, + capabilities_override={}, + enabled=True, + catalog={}, + ) + session.add(model) + await session.commit() + await session.refresh(model) + return _model_read(model) + + @router.put("/models/{model_id}", response_model=ModelRead) async def update_model( model_id: int, diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index c14671c99..2a06eca5c 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -48,6 +48,7 @@ from .model_connections import ( ConnectionCreate, ConnectionRead, ConnectionUpdate, + ModelCreate, ModelRead, ModelRolesRead, ModelRolesUpdate, diff --git a/surfsense_backend/app/schemas/model_connections.py b/surfsense_backend/app/schemas/model_connections.py index 731064375..ea1ec4e88 100644 --- a/surfsense_backend/app/schemas/model_connections.py +++ b/surfsense_backend/app/schemas/model_connections.py @@ -65,6 +65,17 @@ class ConnectionUpdate(BaseModel): enabled: bool | None = None +class ModelCreate(BaseModel): + """Manually register a model id on a connection. + + For providers without a usable ``/models`` endpoint (Perplexity, MiniMax, + Azure deployments, etc.) or to pin a single model from a noisy provider. + """ + + model_id: str = Field(..., max_length=255) + display_name: str | None = Field(None, max_length=255) + + class ModelUpdate(BaseModel): display_name: str | None = Field(None, max_length=255) enabled: bool | None = None diff --git a/surfsense_backend/app/services/model_connection_service.py b/surfsense_backend/app/services/model_connection_service.py index 81090acaf..c8d2e8a5a 100644 --- a/surfsense_backend/app/services/model_connection_service.py +++ b/surfsense_backend/app/services/model_connection_service.py @@ -122,6 +122,17 @@ def _litellm_capabilities(model_string: str, model_id: str) -> dict[str, bool]: return capabilities +def _allowlist(conn: Connection) -> set[str]: + """Per-connection model-id allowlist stored in ``extra.model_ids``. + + Empty/absent means "no restriction" (discover everything), mirroring + OpenWebUI's behaviour. A non-empty list restricts discovery to those ids — + essential for providers like OpenRouter that expose hundreds of models. + """ + raw = (conn.extra or {}).get("model_ids") or [] + return {str(item).strip() for item in raw if str(item).strip()} + + def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = None) -> dict[str, bool]: metadata = metadata or {} if conn.protocol == ConnectionProtocol.OLLAMA: @@ -140,13 +151,15 @@ def derive_capabilities(conn: Connection, model_id: str, metadata: dict | None = async def discover_models(conn: Connection) -> list[dict[str, Any]]: + allowlist = _allowlist(conn) + if conn.protocol == ConnectionProtocol.OLLAMA: url = f"{conn.base_url.rstrip('/')}/api/tags" async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: response = await client.get(url, headers=_auth_headers(conn)) response.raise_for_status() models = response.json().get("models", []) - return [ + results = [ { "model_id": item.get("model") or item.get("name"), "display_name": item.get("name") or item.get("model"), @@ -157,14 +170,13 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]: for item in models if item.get("model") or item.get("name") ] - - if conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: + elif conn.protocol == ConnectionProtocol.OPENAI_COMPATIBLE: url = f"{ensure_v1(conn.base_url)}/models" async with httpx.AsyncClient(timeout=DISCOVERY_TIMEOUT_SECONDS) as client: response = await client.get(url, headers=_auth_headers(conn)) response.raise_for_status() models = response.json().get("data", []) - return [ + results = [ { "model_id": item.get("id"), "display_name": item.get("name") or item.get("id"), @@ -175,9 +187,13 @@ async def discover_models(conn: Connection) -> list[dict[str, Any]]: for item in models if item.get("id") ] + else: + # Native providers rely on curated/global catalog entries or manual rows. + return [] - # Native providers rely on curated/global catalog entries or manual rows. - return [] + if allowlist: + results = [item for item in results if item["model_id"] in allowlist] + return results async def test_model(conn: Connection, model: Model) -> VerifyResult: diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts index 7d58a402c..612216bf2 100644 --- a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -3,6 +3,7 @@ import { toast } from "sonner"; import type { ConnectionCreateRequest, ConnectionUpdateRequest, + ModelCreateRequest, ModelRoles, ModelUpdateRequest, } from "@/contracts/types/model-connections.types"; @@ -67,8 +68,17 @@ export const verifyModelConnectionMutationAtom = atomWithMutation((get) => { mutationKey: ["model-connections", "verify"], mutationFn: (id: number) => modelConnectionsApiService.verifyConnection(id), onSuccess: (result) => { - if (result.ok) toast.success("Connection verified"); - else toast.error(result.message || "Connection failed"); + if (result.ok) { + toast.success("Connection verified"); + } else { + // Non-fatal: many providers lack a /models endpoint yet still serve + // chat. Guide the user to add model IDs manually instead of alarming. + toast.warning( + result.message + ? `${result.message} Chat may still work — add model IDs manually.` + : "Couldn't list models. Chat may still work — add model IDs manually." + ); + } invalidateModelConnections(searchSpaceId); }, onError: (error: Error) => toast.error(error.message || "Failed to verify connection"), @@ -88,6 +98,20 @@ export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => { }; }); +export const addManualModelMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["models", "add-manual"], + mutationFn: ({ connectionId, data }: { connectionId: number; data: ModelCreateRequest }) => + modelConnectionsApiService.addManualModel(connectionId, data), + onSuccess: () => { + toast.success("Model added"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to add model"), + }; +}); + export const updateModelMutationAtom = atomWithMutation((get) => { const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); return { diff --git a/surfsense_web/components/settings/model-connections-settings.tsx b/surfsense_web/components/settings/model-connections-settings.tsx index 5fa4cccf7..e89fc3278 100644 --- a/surfsense_web/components/settings/model-connections-settings.tsx +++ b/surfsense_web/components/settings/model-connections-settings.tsx @@ -1,12 +1,14 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { CheckCircle2, PlugZap, RefreshCcw, XCircle } from "lucide-react"; +import { CheckCircle2, PlugZap, Plus, RefreshCcw, XCircle } from "lucide-react"; import { useMemo, useState } from "react"; import { + addManualModelMutationAtom, createModelConnectionMutationAtom, discoverConnectionModelsMutationAtom, testModelMutationAtom, + updateModelConnectionMutationAtom, updateModelMutationAtom, updateModelRolesMutationAtom, verifyModelConnectionMutationAtom, @@ -46,9 +48,16 @@ type Preset = { }; const PRESETS: Preset[] = [ + { id: "custom", label: "OpenAI-compatible (any URL)", protocol: "OPENAI_COMPATIBLE" }, { id: "openai", label: "OpenAI", protocol: "NATIVE", nativeProvider: "OPENAI" }, { id: "anthropic", label: "Anthropic", protocol: "NATIVE", nativeProvider: "ANTHROPIC" }, - { id: "openrouter", label: "OpenRouter", protocol: "NATIVE", nativeProvider: "OPENROUTER" }, + { + id: "openrouter", + label: "OpenRouter", + protocol: "NATIVE", + nativeProvider: "OPENROUTER", + baseUrl: "https://openrouter.ai/api/v1", + }, { id: "ollama", label: "Ollama", @@ -86,6 +95,22 @@ const PRESETS: Preset[] = [ }, ]; +// Free-text URL hints (datalist), mirroring OpenWebUI. These never restrict +// what the user can type — any OpenAI-compatible endpoint works. +const URL_SUGGESTIONS = [ + "https://api.openai.com/v1", + "https://api.anthropic.com/v1", + "https://openrouter.ai/api/v1", + "https://generativelanguage.googleapis.com/v1beta/openai", + "https://api.groq.com/openai/v1", + "https://api.mistral.ai/v1", + "https://api.deepseek.com/v1", + "https://api.x.ai/v1", + "http://host.docker.internal:11434", + "http://host.docker.internal:1234/v1", + "http://host.docker.internal:8000/v1", +]; + function modelLabel(model: ModelRead) { return model.display_name || model.model_id; } @@ -123,22 +148,183 @@ function flattenModels(connections: ConnectionRead[]) { ); } +function ConnectionCard({ connection }: { connection: ConnectionRead }) { + const verifyConnection = useAtomValue(verifyModelConnectionMutationAtom); + const discoverModels = useAtomValue(discoverConnectionModelsMutationAtom); + const updateConnection = useAtomValue(updateModelConnectionMutationAtom); + const addManualModel = useAtomValue(addManualModelMutationAtom); + const updateModel = useAtomValue(updateModelMutationAtom); + const testModel = useAtomValue(testModelMutationAtom); + + const allowlist = Array.isArray(connection.extra?.model_ids) + ? (connection.extra.model_ids as string[]) + : []; + const [allowlistText, setAllowlistText] = useState(allowlist.join(", ")); + const [manualModelId, setManualModelId] = useState(""); + + const providerLabel = connection.native_provider || connection.protocol; + const isLocal = connection.protocol === "OLLAMA" || !connection.base_url?.startsWith("https"); + + function saveAllowlist() { + const ids = allowlistText + .split(",") + .map((value) => value.trim()) + .filter(Boolean); + updateConnection.mutate({ + id: connection.id, + data: { extra: { ...(connection.extra ?? {}), model_ids: ids } }, + }); + } + + function addModel() { + const modelId = manualModelId.trim(); + if (!modelId) return; + addManualModel.mutate( + { connectionId: connection.id, data: { model_id: modelId } }, + { onSuccess: () => setManualModelId("") } + ); + } + + return ( +
+ {connection.last_error || "Could not list models."} Chat may still work — add model + IDs manually below. +
+ ) : null} + + {!isLocal ? ( ++ Leave empty to discover all models. Recommended for providers with large catalogs + (e.g. OpenRouter). +
++ Works with any OpenAI-compatible endpoint (OpenRouter, Together, Groq, vLLM, LM + Studio…). After adding, hit Discover to list models. +
) : null}{connection.last_error}
- ) : null} -