mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
Add dynamic vision model list from OpenRouter with combobox selector
This commit is contained in:
parent
4dd1b6c64f
commit
26bffbcc47
8 changed files with 346 additions and 7 deletions
23
surfsense_backend/app/config/vision_model_list_fallback.json
Normal file
23
surfsense_backend/app/config/vision_model_list_fallback.json
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
[
|
||||
{"value": "gpt-4o", "label": "GPT-4o", "provider": "OPENAI", "context_window": "128K"},
|
||||
{"value": "gpt-4o-mini", "label": "GPT-4o Mini", "provider": "OPENAI", "context_window": "128K"},
|
||||
{"value": "gpt-4-turbo", "label": "GPT-4 Turbo", "provider": "OPENAI", "context_window": "128K"},
|
||||
{"value": "claude-sonnet-4-20250514", "label": "Claude Sonnet 4", "provider": "ANTHROPIC", "context_window": "200K"},
|
||||
{"value": "claude-3-7-sonnet-20250219", "label": "Claude 3.7 Sonnet", "provider": "ANTHROPIC", "context_window": "200K"},
|
||||
{"value": "claude-3-5-sonnet-20241022", "label": "Claude 3.5 Sonnet", "provider": "ANTHROPIC", "context_window": "200K"},
|
||||
{"value": "claude-3-opus-20240229", "label": "Claude 3 Opus", "provider": "ANTHROPIC", "context_window": "200K"},
|
||||
{"value": "claude-3-haiku-20240307", "label": "Claude 3 Haiku", "provider": "ANTHROPIC", "context_window": "200K"},
|
||||
{"value": "gemini-2.5-flash", "label": "Gemini 2.5 Flash", "provider": "GOOGLE", "context_window": "1M"},
|
||||
{"value": "gemini-2.5-pro", "label": "Gemini 2.5 Pro", "provider": "GOOGLE", "context_window": "1M"},
|
||||
{"value": "gemini-2.0-flash", "label": "Gemini 2.0 Flash", "provider": "GOOGLE", "context_window": "1M"},
|
||||
{"value": "gemini-1.5-pro", "label": "Gemini 1.5 Pro", "provider": "GOOGLE", "context_window": "1M"},
|
||||
{"value": "gemini-1.5-flash", "label": "Gemini 1.5 Flash", "provider": "GOOGLE", "context_window": "1M"},
|
||||
{"value": "pixtral-large-latest", "label": "Pixtral Large", "provider": "MISTRAL", "context_window": "128K"},
|
||||
{"value": "pixtral-12b-2409", "label": "Pixtral 12B", "provider": "MISTRAL", "context_window": "128K"},
|
||||
{"value": "grok-2-vision-1212", "label": "Grok 2 Vision", "provider": "XAI", "context_window": "32K"},
|
||||
{"value": "llava", "label": "LLaVA", "provider": "OLLAMA"},
|
||||
{"value": "bakllava", "label": "BakLLaVA", "provider": "OLLAMA"},
|
||||
{"value": "llava-llama3", "label": "LLaVA Llama 3", "provider": "OLLAMA"},
|
||||
{"value": "llama-4-scout-17b-16e-instruct", "label": "Llama 4 Scout 17B", "provider": "GROQ", "context_window": "128K"},
|
||||
{"value": "meta-llama/Llama-4-Scout-17B-16E-Instruct", "label": "Llama 4 Scout 17B", "provider": "TOGETHER_AI", "context_window": "128K"}
|
||||
]
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
|
@ -17,6 +18,7 @@ from app.schemas import (
|
|||
VisionLLMConfigRead,
|
||||
VisionLLMConfigUpdate,
|
||||
)
|
||||
from app.services.vision_model_list_service import get_vision_model_list
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
|
@ -24,6 +26,32 @@ router = APIRouter()
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Vision Model Catalogue (from OpenRouter, filtered for image-input models)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class VisionModelListItem(BaseModel):
|
||||
value: str
|
||||
label: str
|
||||
provider: str
|
||||
context_window: str | None = None
|
||||
|
||||
|
||||
@router.get("/vision-models", response_model=list[VisionModelListItem])
|
||||
async def list_vision_models(
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Return vision-capable models sourced from OpenRouter (filtered by image input)."""
|
||||
try:
|
||||
return await get_vision_model_list()
|
||||
except Exception as e:
|
||||
logger.exception("Failed to fetch vision model list")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch vision model list: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Vision LLM Configs (from YAML)
|
||||
# =============================================================================
|
||||
|
|
|
|||
132
surfsense_backend/app/services/vision_model_list_service.py
Normal file
132
surfsense_backend/app/services/vision_model_list_service.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
"""
|
||||
Service for fetching and caching the vision-capable model list.
|
||||
|
||||
Reuses the same OpenRouter public API and local fallback as the LLM model
|
||||
list service, but filters for models that accept image input.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
|
||||
FALLBACK_FILE = Path(__file__).parent.parent / "config" / "vision_model_list_fallback.json"
|
||||
CACHE_TTL_SECONDS = 86400 # 24 hours
|
||||
|
||||
_cache: list[dict] | None = None
|
||||
_cache_timestamp: float = 0
|
||||
|
||||
OPENROUTER_SLUG_TO_VISION_PROVIDER: dict[str, str] = {
|
||||
"openai": "OPENAI",
|
||||
"anthropic": "ANTHROPIC",
|
||||
"google": "GOOGLE",
|
||||
"mistralai": "MISTRAL",
|
||||
"x-ai": "XAI",
|
||||
}
|
||||
|
||||
|
||||
def _format_context_length(length: int | None) -> str | None:
|
||||
if not length:
|
||||
return None
|
||||
if length >= 1_000_000:
|
||||
return f"{length / 1_000_000:g}M"
|
||||
if length >= 1_000:
|
||||
return f"{length / 1_000:g}K"
|
||||
return str(length)
|
||||
|
||||
|
||||
async def _fetch_from_openrouter() -> list[dict] | None:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=15) as client:
|
||||
response = await client.get(OPENROUTER_API_URL)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("data", [])
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fetch from OpenRouter API for vision models: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def _load_fallback() -> list[dict]:
|
||||
try:
|
||||
with open(FALLBACK_FILE, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error("Failed to load vision model fallback list: %s", e)
|
||||
return []
|
||||
|
||||
|
||||
def _is_vision_model(model: dict) -> bool:
|
||||
"""Return True if the model accepts image input and outputs text."""
|
||||
arch = model.get("architecture", {})
|
||||
input_mods = arch.get("input_modalities", [])
|
||||
output_mods = arch.get("output_modalities", [])
|
||||
return "image" in input_mods and "text" in output_mods
|
||||
|
||||
|
||||
def _process_vision_models(raw_models: list[dict]) -> list[dict]:
|
||||
processed: list[dict] = []
|
||||
|
||||
for model in raw_models:
|
||||
model_id: str = model.get("id", "")
|
||||
name: str = model.get("name", "")
|
||||
context_length = model.get("context_length")
|
||||
|
||||
if "/" not in model_id:
|
||||
continue
|
||||
|
||||
if not _is_vision_model(model):
|
||||
continue
|
||||
|
||||
provider_slug, model_name = model_id.split("/", 1)
|
||||
context_window = _format_context_length(context_length)
|
||||
|
||||
processed.append(
|
||||
{
|
||||
"value": model_id,
|
||||
"label": name,
|
||||
"provider": "OPENROUTER",
|
||||
"context_window": context_window,
|
||||
}
|
||||
)
|
||||
|
||||
native_provider = OPENROUTER_SLUG_TO_VISION_PROVIDER.get(provider_slug)
|
||||
if native_provider:
|
||||
if native_provider == "GOOGLE" and not model_name.startswith("gemini-"):
|
||||
continue
|
||||
|
||||
processed.append(
|
||||
{
|
||||
"value": model_name,
|
||||
"label": name,
|
||||
"provider": native_provider,
|
||||
"context_window": context_window,
|
||||
}
|
||||
)
|
||||
|
||||
return processed
|
||||
|
||||
|
||||
async def get_vision_model_list() -> list[dict]:
|
||||
global _cache, _cache_timestamp
|
||||
|
||||
if _cache is not None and (time.time() - _cache_timestamp) < CACHE_TTL_SECONDS:
|
||||
return _cache
|
||||
|
||||
raw_models = await _fetch_from_openrouter()
|
||||
|
||||
if raw_models is None:
|
||||
logger.info("Using fallback vision model list")
|
||||
return _load_fallback()
|
||||
|
||||
processed = _process_vision_models(raw_models)
|
||||
|
||||
_cache = processed
|
||||
_cache_timestamp = time.time()
|
||||
|
||||
return processed
|
||||
|
|
@ -1,4 +1,6 @@
|
|||
import { atomWithQuery } from "jotai-tanstack-query";
|
||||
import type { LLMModel } from "@/contracts/enums/llm-models";
|
||||
import { VISION_MODELS } from "@/contracts/enums/vision-providers";
|
||||
import { visionLLMConfigApiService } from "@/lib/apis/vision-llm-config-api.service";
|
||||
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
||||
import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms";
|
||||
|
|
@ -25,3 +27,25 @@ export const globalVisionLLMConfigsAtom = atomWithQuery(() => {
|
|||
},
|
||||
};
|
||||
});
|
||||
|
||||
export const visionModelListAtom = atomWithQuery(() => {
|
||||
return {
|
||||
queryKey: cacheKeys.visionLLMConfigs.modelList(),
|
||||
staleTime: 60 * 60 * 1000,
|
||||
placeholderData: VISION_MODELS,
|
||||
queryFn: async (): Promise<LLMModel[]> => {
|
||||
const data = await visionLLMConfigApiService.getModels();
|
||||
const dynamicModels = data.map((m) => ({
|
||||
value: m.value,
|
||||
label: m.label,
|
||||
provider: m.provider,
|
||||
contextWindow: m.context_window ?? undefined,
|
||||
}));
|
||||
|
||||
const coveredProviders = new Set(dynamicModels.map((m) => m.provider));
|
||||
const staticFallbacks = VISION_MODELS.filter((m) => !coveredProviders.has(m.provider));
|
||||
|
||||
return [...dynamicModels, ...staticFallbacks];
|
||||
},
|
||||
};
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,20 +1,30 @@
|
|||
"use client";
|
||||
|
||||
import { useAtomValue } from "jotai";
|
||||
import { AlertCircle } from "lucide-react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { AlertCircle, Check, ChevronsUpDown } from "lucide-react";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms";
|
||||
import {
|
||||
createVisionLLMConfigMutationAtom,
|
||||
updateVisionLLMConfigMutationAtom,
|
||||
} from "@/atoms/vision-llm-config/vision-llm-config-mutation.atoms";
|
||||
import { visionModelListAtom } from "@/atoms/vision-llm-config/vision-llm-config-query.atoms";
|
||||
import { Alert, AlertDescription } from "@/components/ui/alert";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Command,
|
||||
CommandEmpty,
|
||||
CommandGroup,
|
||||
CommandInput,
|
||||
CommandItem,
|
||||
CommandList,
|
||||
} from "@/components/ui/command";
|
||||
import { Dialog, DialogContent, DialogTitle } from "@/components/ui/dialog";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
|
|
@ -30,6 +40,7 @@ import type {
|
|||
VisionLLMConfig,
|
||||
VisionProvider,
|
||||
} from "@/contracts/types/new-llm-config.types";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface VisionConfigDialogProps {
|
||||
open: boolean;
|
||||
|
|
@ -177,6 +188,14 @@ export function VisionConfigDialog({
|
|||
}
|
||||
}, [config, isGlobal, searchSpaceId, updatePreferences, onOpenChange]);
|
||||
|
||||
const { data: dynamicModels } = useAtomValue(visionModelListAtom);
|
||||
const [modelComboboxOpen, setModelComboboxOpen] = useState(false);
|
||||
|
||||
const availableModels = useMemo(
|
||||
() => (dynamicModels ?? []).filter((m) => m.provider === formData.provider),
|
||||
[dynamicModels, formData.provider]
|
||||
);
|
||||
|
||||
const isFormValid = formData.name && formData.provider && formData.model_name && formData.api_key;
|
||||
const selectedProvider = VISION_PROVIDERS.find((p) => p.value === formData.provider);
|
||||
|
||||
|
|
@ -303,11 +322,92 @@ export function VisionConfigDialog({
|
|||
|
||||
<div className="space-y-2">
|
||||
<Label className="text-sm font-medium">Model Name *</Label>
|
||||
<Input
|
||||
placeholder={selectedProvider?.example?.split(",")[0]?.trim() || "e.g., gpt-4o"}
|
||||
value={formData.model_name}
|
||||
onChange={(e) => setFormData((p) => ({ ...p, model_name: e.target.value }))}
|
||||
/>
|
||||
<Popover open={modelComboboxOpen} onOpenChange={setModelComboboxOpen}>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
variant="outline"
|
||||
role="combobox"
|
||||
aria-expanded={modelComboboxOpen}
|
||||
className={cn(
|
||||
"w-full justify-between font-normal bg-transparent",
|
||||
!formData.model_name && "text-muted-foreground"
|
||||
)}
|
||||
>
|
||||
{formData.model_name || "Select a model"}
|
||||
<ChevronsUpDown className="ml-2 h-4 w-4 shrink-0 opacity-50" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent
|
||||
className="w-full p-0 bg-muted dark:border-neutral-700"
|
||||
align="start"
|
||||
>
|
||||
<Command shouldFilter={false} className="bg-transparent">
|
||||
<CommandInput
|
||||
placeholder={selectedProvider?.example || "Search model name"}
|
||||
value={formData.model_name}
|
||||
onValueChange={(val) =>
|
||||
setFormData((p) => ({ ...p, model_name: val }))
|
||||
}
|
||||
/>
|
||||
<CommandList className="max-h-[300px]">
|
||||
<CommandEmpty>
|
||||
<div className="py-3 text-center text-sm text-muted-foreground">
|
||||
{formData.model_name
|
||||
? `Using: "${formData.model_name}"`
|
||||
: "Type your model name"}
|
||||
</div>
|
||||
</CommandEmpty>
|
||||
{availableModels.length > 0 && (
|
||||
<CommandGroup heading="Suggested Models">
|
||||
{availableModels
|
||||
.filter(
|
||||
(model) =>
|
||||
!formData.model_name ||
|
||||
model.value
|
||||
.toLowerCase()
|
||||
.includes(formData.model_name.toLowerCase()) ||
|
||||
model.label
|
||||
.toLowerCase()
|
||||
.includes(formData.model_name.toLowerCase())
|
||||
)
|
||||
.slice(0, 50)
|
||||
.map((model) => (
|
||||
<CommandItem
|
||||
key={model.value}
|
||||
value={model.value}
|
||||
onSelect={(value) => {
|
||||
setFormData((p) => ({
|
||||
...p,
|
||||
model_name: value,
|
||||
}));
|
||||
setModelComboboxOpen(false);
|
||||
}}
|
||||
className="py-2"
|
||||
>
|
||||
<Check
|
||||
className={cn(
|
||||
"mr-2 h-4 w-4",
|
||||
formData.model_name === model.value
|
||||
? "opacity-100"
|
||||
: "opacity-0"
|
||||
)}
|
||||
/>
|
||||
<div>
|
||||
<div className="font-medium">{model.label}</div>
|
||||
{model.contextWindow && (
|
||||
<div className="text-xs text-muted-foreground">
|
||||
Context: {model.contextWindow}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</CommandItem>
|
||||
))}
|
||||
</CommandGroup>
|
||||
)}
|
||||
</CommandList>
|
||||
</Command>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import type { LLMModel } from "./llm-models";
|
||||
|
||||
export interface VisionProviderInfo {
|
||||
value: string;
|
||||
label: string;
|
||||
|
|
@ -100,3 +102,27 @@ export const VISION_PROVIDERS: VisionProviderInfo[] = [
|
|||
description: "Custom OpenAI-compatible vision endpoint",
|
||||
},
|
||||
];
|
||||
|
||||
export const VISION_MODELS: LLMModel[] = [
|
||||
{ value: "gpt-4o", label: "GPT-4o", provider: "OPENAI", contextWindow: "128K" },
|
||||
{ value: "gpt-4o-mini", label: "GPT-4o Mini", provider: "OPENAI", contextWindow: "128K" },
|
||||
{ value: "gpt-4-turbo", label: "GPT-4 Turbo", provider: "OPENAI", contextWindow: "128K" },
|
||||
{ value: "claude-sonnet-4-20250514", label: "Claude Sonnet 4", provider: "ANTHROPIC", contextWindow: "200K" },
|
||||
{ value: "claude-3-7-sonnet-20250219", label: "Claude 3.7 Sonnet", provider: "ANTHROPIC", contextWindow: "200K" },
|
||||
{ value: "claude-3-5-sonnet-20241022", label: "Claude 3.5 Sonnet", provider: "ANTHROPIC", contextWindow: "200K" },
|
||||
{ value: "claude-3-opus-20240229", label: "Claude 3 Opus", provider: "ANTHROPIC", contextWindow: "200K" },
|
||||
{ value: "claude-3-haiku-20240307", label: "Claude 3 Haiku", provider: "ANTHROPIC", contextWindow: "200K" },
|
||||
{ value: "gemini-2.5-flash", label: "Gemini 2.5 Flash", provider: "GOOGLE", contextWindow: "1M" },
|
||||
{ value: "gemini-2.5-pro", label: "Gemini 2.5 Pro", provider: "GOOGLE", contextWindow: "1M" },
|
||||
{ value: "gemini-2.0-flash", label: "Gemini 2.0 Flash", provider: "GOOGLE", contextWindow: "1M" },
|
||||
{ value: "gemini-1.5-pro", label: "Gemini 1.5 Pro", provider: "GOOGLE", contextWindow: "1M" },
|
||||
{ value: "gemini-1.5-flash", label: "Gemini 1.5 Flash", provider: "GOOGLE", contextWindow: "1M" },
|
||||
{ value: "pixtral-large-latest", label: "Pixtral Large", provider: "MISTRAL", contextWindow: "128K" },
|
||||
{ value: "pixtral-12b-2409", label: "Pixtral 12B", provider: "MISTRAL", contextWindow: "128K" },
|
||||
{ value: "grok-2-vision-1212", label: "Grok 2 Vision", provider: "XAI", contextWindow: "32K" },
|
||||
{ value: "llava", label: "LLaVA", provider: "OLLAMA" },
|
||||
{ value: "bakllava", label: "BakLLaVA", provider: "OLLAMA" },
|
||||
{ value: "llava-llama3", label: "LLaVA Llama 3", provider: "OLLAMA" },
|
||||
{ value: "llama-4-scout-17b-16e-instruct", label: "Llama 4 Scout 17B", provider: "GROQ", contextWindow: "128K" },
|
||||
{ value: "meta-llama/Llama-4-Scout-17B-16E-Instruct", label: "Llama 4 Scout 17B", provider: "TOGETHER_AI", contextWindow: "128K" },
|
||||
];
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import {
|
|||
createVisionLLMConfigResponse,
|
||||
deleteVisionLLMConfigResponse,
|
||||
getGlobalVisionLLMConfigsResponse,
|
||||
getModelListResponse,
|
||||
getVisionLLMConfigsResponse,
|
||||
type UpdateVisionLLMConfigRequest,
|
||||
updateVisionLLMConfigRequest,
|
||||
|
|
@ -13,6 +14,10 @@ import { ValidationError } from "../error";
|
|||
import { baseApiService } from "./base-api.service";
|
||||
|
||||
class VisionLLMConfigApiService {
|
||||
getModels = async () => {
|
||||
return baseApiService.get(`/api/v1/vision-models`, getModelListResponse);
|
||||
};
|
||||
|
||||
getGlobalConfigs = async () => {
|
||||
return baseApiService.get(
|
||||
`/api/v1/global-vision-llm-configs`,
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ export const cacheKeys = {
|
|||
all: (searchSpaceId: number) => ["vision-llm-configs", searchSpaceId] as const,
|
||||
byId: (configId: number) => ["vision-llm-configs", "detail", configId] as const,
|
||||
global: () => ["vision-llm-configs", "global"] as const,
|
||||
modelList: () => ["vision-models", "catalogue"] as const,
|
||||
},
|
||||
auth: {
|
||||
user: ["auth", "user"] as const,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue