diff --git a/surfsense_backend/app/config/vision_model_list_fallback.json b/surfsense_backend/app/config/vision_model_list_fallback.json new file mode 100644 index 000000000..830eb6517 --- /dev/null +++ b/surfsense_backend/app/config/vision_model_list_fallback.json @@ -0,0 +1,23 @@ +[ + {"value": "gpt-4o", "label": "GPT-4o", "provider": "OPENAI", "context_window": "128K"}, + {"value": "gpt-4o-mini", "label": "GPT-4o Mini", "provider": "OPENAI", "context_window": "128K"}, + {"value": "gpt-4-turbo", "label": "GPT-4 Turbo", "provider": "OPENAI", "context_window": "128K"}, + {"value": "claude-sonnet-4-20250514", "label": "Claude Sonnet 4", "provider": "ANTHROPIC", "context_window": "200K"}, + {"value": "claude-3-7-sonnet-20250219", "label": "Claude 3.7 Sonnet", "provider": "ANTHROPIC", "context_window": "200K"}, + {"value": "claude-3-5-sonnet-20241022", "label": "Claude 3.5 Sonnet", "provider": "ANTHROPIC", "context_window": "200K"}, + {"value": "claude-3-opus-20240229", "label": "Claude 3 Opus", "provider": "ANTHROPIC", "context_window": "200K"}, + {"value": "claude-3-haiku-20240307", "label": "Claude 3 Haiku", "provider": "ANTHROPIC", "context_window": "200K"}, + {"value": "gemini-2.5-flash", "label": "Gemini 2.5 Flash", "provider": "GOOGLE", "context_window": "1M"}, + {"value": "gemini-2.5-pro", "label": "Gemini 2.5 Pro", "provider": "GOOGLE", "context_window": "1M"}, + {"value": "gemini-2.0-flash", "label": "Gemini 2.0 Flash", "provider": "GOOGLE", "context_window": "1M"}, + {"value": "gemini-1.5-pro", "label": "Gemini 1.5 Pro", "provider": "GOOGLE", "context_window": "1M"}, + {"value": "gemini-1.5-flash", "label": "Gemini 1.5 Flash", "provider": "GOOGLE", "context_window": "1M"}, + {"value": "pixtral-large-latest", "label": "Pixtral Large", "provider": "MISTRAL", "context_window": "128K"}, + {"value": "pixtral-12b-2409", "label": "Pixtral 12B", "provider": "MISTRAL", "context_window": "128K"}, + {"value": "grok-2-vision-1212", "label": "Grok 2 Vision", "provider": "XAI", "context_window": "32K"}, + {"value": "llava", "label": "LLaVA", "provider": "OLLAMA"}, + {"value": "bakllava", "label": "BakLLaVA", "provider": "OLLAMA"}, + {"value": "llava-llama3", "label": "LLaVA Llama 3", "provider": "OLLAMA"}, + {"value": "llama-4-scout-17b-16e-instruct", "label": "Llama 4 Scout 17B", "provider": "GROQ", "context_window": "128K"}, + {"value": "meta-llama/Llama-4-Scout-17B-16E-Instruct", "label": "Llama 4 Scout 17B", "provider": "TOGETHER_AI", "context_window": "128K"} +] diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py index 29d1a2757..eddd5e367 100644 --- a/surfsense_backend/app/routes/vision_llm_routes.py +++ b/surfsense_backend/app/routes/vision_llm_routes.py @@ -1,6 +1,7 @@ import logging from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -17,6 +18,7 @@ from app.schemas import ( VisionLLMConfigRead, VisionLLMConfigUpdate, ) +from app.services.vision_model_list_service import get_vision_model_list from app.users import current_active_user from app.utils.rbac import check_permission @@ -24,6 +26,32 @@ router = APIRouter() logger = logging.getLogger(__name__) +# ============================================================================= +# Vision Model Catalogue (from OpenRouter, filtered for image-input models) +# ============================================================================= + + +class VisionModelListItem(BaseModel): + value: str + label: str + provider: str + context_window: str | None = None + + +@router.get("/vision-models", response_model=list[VisionModelListItem]) +async def list_vision_models( + user: User = Depends(current_active_user), +): + """Return vision-capable models sourced from OpenRouter (filtered by image input).""" + try: + return await get_vision_model_list() + except Exception as e: + logger.exception("Failed to fetch vision model list") + raise HTTPException( + status_code=500, detail=f"Failed to fetch vision model list: {e!s}" + ) from e + + # ============================================================================= # Global Vision LLM Configs (from YAML) # ============================================================================= diff --git a/surfsense_backend/app/services/vision_model_list_service.py b/surfsense_backend/app/services/vision_model_list_service.py new file mode 100644 index 000000000..09893dd06 --- /dev/null +++ b/surfsense_backend/app/services/vision_model_list_service.py @@ -0,0 +1,132 @@ +""" +Service for fetching and caching the vision-capable model list. + +Reuses the same OpenRouter public API and local fallback as the LLM model +list service, but filters for models that accept image input. +""" + +import json +import logging +import time +from pathlib import Path + +import httpx + +logger = logging.getLogger(__name__) + +OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models" +FALLBACK_FILE = Path(__file__).parent.parent / "config" / "vision_model_list_fallback.json" +CACHE_TTL_SECONDS = 86400 # 24 hours + +_cache: list[dict] | None = None +_cache_timestamp: float = 0 + +OPENROUTER_SLUG_TO_VISION_PROVIDER: dict[str, str] = { + "openai": "OPENAI", + "anthropic": "ANTHROPIC", + "google": "GOOGLE", + "mistralai": "MISTRAL", + "x-ai": "XAI", +} + + +def _format_context_length(length: int | None) -> str | None: + if not length: + return None + if length >= 1_000_000: + return f"{length / 1_000_000:g}M" + if length >= 1_000: + return f"{length / 1_000:g}K" + return str(length) + + +async def _fetch_from_openrouter() -> list[dict] | None: + try: + async with httpx.AsyncClient(timeout=15) as client: + response = await client.get(OPENROUTER_API_URL) + response.raise_for_status() + data = response.json() + return data.get("data", []) + except Exception as e: + logger.warning("Failed to fetch from OpenRouter API for vision models: %s", e) + return None + + +def _load_fallback() -> list[dict]: + try: + with open(FALLBACK_FILE, encoding="utf-8") as f: + return json.load(f) + except Exception as e: + logger.error("Failed to load vision model fallback list: %s", e) + return [] + + +def _is_vision_model(model: dict) -> bool: + """Return True if the model accepts image input and outputs text.""" + arch = model.get("architecture", {}) + input_mods = arch.get("input_modalities", []) + output_mods = arch.get("output_modalities", []) + return "image" in input_mods and "text" in output_mods + + +def _process_vision_models(raw_models: list[dict]) -> list[dict]: + processed: list[dict] = [] + + for model in raw_models: + model_id: str = model.get("id", "") + name: str = model.get("name", "") + context_length = model.get("context_length") + + if "/" not in model_id: + continue + + if not _is_vision_model(model): + continue + + provider_slug, model_name = model_id.split("/", 1) + context_window = _format_context_length(context_length) + + processed.append( + { + "value": model_id, + "label": name, + "provider": "OPENROUTER", + "context_window": context_window, + } + ) + + native_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug) + if native_provider: + if native_provider == "GOOGLE" and not model_name.startswith("gemini-"): + continue + + processed.append( + { + "value": model_name, + "label": name, + "provider": native_provider, + "context_window": context_window, + } + ) + + return processed + + +async def get_vision_model_list() -> list[dict]: + global _cache, _cache_timestamp + + if _cache is not None and (time.time() - _cache_timestamp) < CACHE_TTL_SECONDS: + return _cache + + raw_models = await _fetch_from_openrouter() + + if raw_models is None: + logger.info("Using fallback vision model list") + return _load_fallback() + + processed = _process_vision_models(raw_models) + + _cache = processed + _cache_timestamp = time.time() + + return processed diff --git a/surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts b/surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts index 53264fb24..906ce638f 100644 --- a/surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts +++ b/surfsense_web/atoms/vision-llm-config/vision-llm-config-query.atoms.ts @@ -1,4 +1,6 @@ import { atomWithQuery } from "jotai-tanstack-query"; +import type { LLMModel } from "@/contracts/enums/llm-models"; +import { VISION_MODELS } from "@/contracts/enums/vision-providers"; import { visionLLMConfigApiService } from "@/lib/apis/vision-llm-config-api.service"; import { cacheKeys } from "@/lib/query-client/cache-keys"; import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; @@ -25,3 +27,25 @@ export const globalVisionLLMConfigsAtom = atomWithQuery(() => { }, }; }); + +export const visionModelListAtom = atomWithQuery(() => { + return { + queryKey: cacheKeys.visionLLMConfigs.modelList(), + staleTime: 60 * 60 * 1000, + placeholderData: VISION_MODELS, + queryFn: async (): Promise => { + const data = await visionLLMConfigApiService.getModels(); + const dynamicModels = data.map((m) => ({ + value: m.value, + label: m.label, + provider: m.provider, + contextWindow: m.context_window ?? undefined, + })); + + const coveredProviders = new Set(dynamicModels.map((m) => m.provider)); + const staticFallbacks = VISION_MODELS.filter((m) => !coveredProviders.has(m.provider)); + + return [...dynamicModels, ...staticFallbacks]; + }, + }; +}); diff --git a/surfsense_web/components/shared/vision-config-dialog.tsx b/surfsense_web/components/shared/vision-config-dialog.tsx index d69750316..6a494e0a6 100644 --- a/surfsense_web/components/shared/vision-config-dialog.tsx +++ b/surfsense_web/components/shared/vision-config-dialog.tsx @@ -1,20 +1,30 @@ "use client"; import { useAtomValue } from "jotai"; -import { AlertCircle } from "lucide-react"; -import { useCallback, useEffect, useRef, useState } from "react"; +import { AlertCircle, Check, ChevronsUpDown } from "lucide-react"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms"; import { createVisionLLMConfigMutationAtom, updateVisionLLMConfigMutationAtom, } from "@/atoms/vision-llm-config/vision-llm-config-mutation.atoms"; +import { visionModelListAtom } from "@/atoms/vision-llm-config/vision-llm-config-query.atoms"; import { Alert, AlertDescription } from "@/components/ui/alert"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, +} from "@/components/ui/command"; import { Dialog, DialogContent, DialogTitle } from "@/components/ui/dialog"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Select, SelectContent, @@ -30,6 +40,7 @@ import type { VisionLLMConfig, VisionProvider, } from "@/contracts/types/new-llm-config.types"; +import { cn } from "@/lib/utils"; interface VisionConfigDialogProps { open: boolean; @@ -177,6 +188,14 @@ export function VisionConfigDialog({ } }, [config, isGlobal, searchSpaceId, updatePreferences, onOpenChange]); + const { data: dynamicModels } = useAtomValue(visionModelListAtom); + const [modelComboboxOpen, setModelComboboxOpen] = useState(false); + + const availableModels = useMemo( + () => (dynamicModels ?? []).filter((m) => m.provider === formData.provider), + [dynamicModels, formData.provider] + ); + const isFormValid = formData.name && formData.provider && formData.model_name && formData.api_key; const selectedProvider = VISION_PROVIDERS.find((p) => p.value === formData.provider); @@ -303,11 +322,92 @@ export function VisionConfigDialog({
- setFormData((p) => ({ ...p, model_name: e.target.value }))} - /> + + + + + + + + setFormData((p) => ({ ...p, model_name: val })) + } + /> + + +
+ {formData.model_name + ? `Using: "${formData.model_name}"` + : "Type your model name"} +
+
+ {availableModels.length > 0 && ( + + {availableModels + .filter( + (model) => + !formData.model_name || + model.value + .toLowerCase() + .includes(formData.model_name.toLowerCase()) || + model.label + .toLowerCase() + .includes(formData.model_name.toLowerCase()) + ) + .slice(0, 50) + .map((model) => ( + { + setFormData((p) => ({ + ...p, + model_name: value, + })); + setModelComboboxOpen(false); + }} + className="py-2" + > + +
+
{model.label}
+ {model.contextWindow && ( +
+ Context: {model.contextWindow} +
+ )} +
+
+ ))} +
+ )} +
+
+
+
diff --git a/surfsense_web/contracts/enums/vision-providers.ts b/surfsense_web/contracts/enums/vision-providers.ts index 260b03585..08be93b74 100644 --- a/surfsense_web/contracts/enums/vision-providers.ts +++ b/surfsense_web/contracts/enums/vision-providers.ts @@ -1,3 +1,5 @@ +import type { LLMModel } from "./llm-models"; + export interface VisionProviderInfo { value: string; label: string; @@ -100,3 +102,27 @@ export const VISION_PROVIDERS: VisionProviderInfo[] = [ description: "Custom OpenAI-compatible vision endpoint", }, ]; + +export const VISION_MODELS: LLMModel[] = [ + { value: "gpt-4o", label: "GPT-4o", provider: "OPENAI", contextWindow: "128K" }, + { value: "gpt-4o-mini", label: "GPT-4o Mini", provider: "OPENAI", contextWindow: "128K" }, + { value: "gpt-4-turbo", label: "GPT-4 Turbo", provider: "OPENAI", contextWindow: "128K" }, + { value: "claude-sonnet-4-20250514", label: "Claude Sonnet 4", provider: "ANTHROPIC", contextWindow: "200K" }, + { value: "claude-3-7-sonnet-20250219", label: "Claude 3.7 Sonnet", provider: "ANTHROPIC", contextWindow: "200K" }, + { value: "claude-3-5-sonnet-20241022", label: "Claude 3.5 Sonnet", provider: "ANTHROPIC", contextWindow: "200K" }, + { value: "claude-3-opus-20240229", label: "Claude 3 Opus", provider: "ANTHROPIC", contextWindow: "200K" }, + { value: "claude-3-haiku-20240307", label: "Claude 3 Haiku", provider: "ANTHROPIC", contextWindow: "200K" }, + { value: "gemini-2.5-flash", label: "Gemini 2.5 Flash", provider: "GOOGLE", contextWindow: "1M" }, + { value: "gemini-2.5-pro", label: "Gemini 2.5 Pro", provider: "GOOGLE", contextWindow: "1M" }, + { value: "gemini-2.0-flash", label: "Gemini 2.0 Flash", provider: "GOOGLE", contextWindow: "1M" }, + { value: "gemini-1.5-pro", label: "Gemini 1.5 Pro", provider: "GOOGLE", contextWindow: "1M" }, + { value: "gemini-1.5-flash", label: "Gemini 1.5 Flash", provider: "GOOGLE", contextWindow: "1M" }, + { value: "pixtral-large-latest", label: "Pixtral Large", provider: "MISTRAL", contextWindow: "128K" }, + { value: "pixtral-12b-2409", label: "Pixtral 12B", provider: "MISTRAL", contextWindow: "128K" }, + { value: "grok-2-vision-1212", label: "Grok 2 Vision", provider: "XAI", contextWindow: "32K" }, + { value: "llava", label: "LLaVA", provider: "OLLAMA" }, + { value: "bakllava", label: "BakLLaVA", provider: "OLLAMA" }, + { value: "llava-llama3", label: "LLaVA Llama 3", provider: "OLLAMA" }, + { value: "llama-4-scout-17b-16e-instruct", label: "Llama 4 Scout 17B", provider: "GROQ", contextWindow: "128K" }, + { value: "meta-llama/Llama-4-Scout-17B-16E-Instruct", label: "Llama 4 Scout 17B", provider: "TOGETHER_AI", contextWindow: "128K" }, +]; diff --git a/surfsense_web/lib/apis/vision-llm-config-api.service.ts b/surfsense_web/lib/apis/vision-llm-config-api.service.ts index 4099c6b39..537cecbd1 100644 --- a/surfsense_web/lib/apis/vision-llm-config-api.service.ts +++ b/surfsense_web/lib/apis/vision-llm-config-api.service.ts @@ -4,6 +4,7 @@ import { createVisionLLMConfigResponse, deleteVisionLLMConfigResponse, getGlobalVisionLLMConfigsResponse, + getModelListResponse, getVisionLLMConfigsResponse, type UpdateVisionLLMConfigRequest, updateVisionLLMConfigRequest, @@ -13,6 +14,10 @@ import { ValidationError } from "../error"; import { baseApiService } from "./base-api.service"; class VisionLLMConfigApiService { + getModels = async () => { + return baseApiService.get(`/api/v1/vision-models`, getModelListResponse); + }; + getGlobalConfigs = async () => { return baseApiService.get( `/api/v1/global-vision-llm-configs`, diff --git a/surfsense_web/lib/query-client/cache-keys.ts b/surfsense_web/lib/query-client/cache-keys.ts index 04f348ff8..10aba7ef4 100644 --- a/surfsense_web/lib/query-client/cache-keys.ts +++ b/surfsense_web/lib/query-client/cache-keys.ts @@ -43,6 +43,7 @@ export const cacheKeys = { all: (searchSpaceId: number) => ["vision-llm-configs", searchSpaceId] as const, byId: (configId: number) => ["vision-llm-configs", "detail", configId] as const, global: () => ["vision-llm-configs", "global"] as const, + modelList: () => ["vision-models", "catalogue"] as const, }, auth: { user: ["auth", "user"] as const,