diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 0f5f50849..90226dde5 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -1,8 +1,10 @@ "use client"; import { useAtom, useAtomValue } from "jotai"; -import { Check, ChevronDown, Cpu, ImageOff, Search, Settings2, Zap } from "lucide-react"; -import { useMemo, useState } from "react"; +import { Check, ChevronDown, Cpu, Search, Settings2, Zap } from "lucide-react"; +import { useRouter } from "next/navigation"; +import type { UIEvent } from "react"; +import { useCallback, useMemo, useState } from "react"; import { updateModelRolesMutationAtom } from "@/atoms/model-connections/model-connections-mutation.atoms"; import { globalModelConnectionsAtom, @@ -23,41 +25,30 @@ import { Input } from "@/components/ui/input"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Spinner } from "@/components/ui/spinner"; import type { ConnectionRead, ModelRead } from "@/contracts/types/model-connections.types"; -import type { - GlobalImageGenConfig, - GlobalNewLLMConfig, - GlobalVisionLLMConfig, - ImageGenerationConfig, - NewLLMConfigPublic, - VisionLLMConfig, -} from "@/contracts/types/new-llm-config.types"; import { useIsMobile } from "@/hooks/use-mobile"; import { getProviderIcon } from "@/lib/provider-icons"; import { cn } from "@/lib/utils"; +import { providerDisplay } from "../settings/model-connections/provider-metadata"; interface ModelSelectorProps { - onEditLLM: (config: NewLLMConfigPublic | GlobalNewLLMConfig, isGlobal: boolean) => void; - onAddNewLLM: (provider?: string) => void; - onEditImage?: (config: ImageGenerationConfig | GlobalImageGenConfig, isGlobal: boolean) => void; - onAddNewImage?: (provider?: string) => void; - onEditVision?: (config: VisionLLMConfig | GlobalVisionLLMConfig, isGlobal: boolean) => void; - onAddNewVision?: (provider?: string) => void; + searchSpaceId: number; className?: string; } type ChatModel = ModelRead & { connectionId: number; connectionLabel: string; + connectionScope: string; provider: string; }; function modelName(model: ModelRead) { - return model.display_name || model.model_id; + return (model.display_name || model.model_id).replace(/\s+\(free\)$/i, ""); } function connectionLabel(connection: ConnectionRead) { - if (connection.scope === "GLOBAL") return "Hosted"; - return connection.provider; + if (connection.scope === "GLOBAL") return "Global"; + return providerDisplay(connection.provider).name; } function flattenChatModels(connections: ConnectionRead[]) { @@ -68,11 +59,16 @@ function flattenChatModels(connections: ConnectionRead[]) { ...model, connectionId: connection.id, connectionLabel: connectionLabel(connection), + connectionScope: connection.scope, provider: connection.provider, })) ); } +function isFreeGlobalModel(model: ChatModel) { + return model.connectionScope === "GLOBAL" && model.billing_tier?.toLowerCase() === "free"; +} + function groupedModels(models: ChatModel[]) { return models.reduce>((groups, model) => { const key = model.connectionLabel; @@ -83,23 +79,14 @@ function groupedModels(models: ChatModel[]) { } export function ModelSelector({ - onAddNewLLM, - onEditLLM, - onEditImage, - onAddNewImage, - onEditVision, - onAddNewVision, + searchSpaceId, className, }: ModelSelectorProps) { - void onEditLLM; - void onEditImage; - void onAddNewImage; - void onEditVision; - void onAddNewVision; - + const router = useRouter(); const isMobile = useIsMobile(); const [open, setOpen] = useState(false); const [search, setSearch] = useState(""); + const [scrollPos, setScrollPos] = useState<"top" | "middle" | "bottom">("top"); const [{ data: globalConnections = [], isLoading: globalLoading }] = useAtom( globalModelConnectionsAtom ); @@ -130,11 +117,18 @@ export function ModelSelector({ function manageModelConnections() { setOpen(false); - onAddNewLLM(); + router.push(`/dashboard/${searchSpaceId}/search-space-settings/models`); } + const handleScroll = useCallback((event: UIEvent) => { + const el = event.currentTarget; + const atTop = el.scrollTop <= 2; + const atBottom = el.scrollHeight - el.scrollTop - el.clientHeight <= 2; + setScrollPos(atTop ? "top" : atBottom ? "bottom" : "middle"); + }, []); + const content = ( -
+
@@ -146,7 +140,14 @@ export function ModelSelector({ />
-
+
@@ -228,6 +243,7 @@ export function ModelSelector({ size="sm" className={cn( "h-8 min-w-0 gap-2 rounded-md px-3 text-muted-foreground transition-colors", + "select-none", "hover:bg-foreground/10 hover:text-foreground", "data-[state=open]:bg-foreground/10 data-[state=open]:text-foreground", className @@ -263,7 +279,7 @@ export function ModelSelector({ return ( {trigger} - + {content}