diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index bad48d9ab..23c98889c 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -43,6 +43,8 @@ type ChatModel = ModelRead & { provider: string; }; +const AUTO_CHAT_MODEL_ID = 0; + function connectionLabel(connection: ConnectionRead) { if (connection.scope === "GLOBAL") return "Global"; return providerDisplay(connection.provider).name; @@ -74,6 +76,17 @@ function modelName(model: ChatModel) { return name; } +function filterChatModels(models: ChatModel[], search: string) { + const normalized = search.trim().toLowerCase(); + if (!normalized) return models; + return models.filter((model) => + [modelName(model), model.model_id, model.connectionLabel] + .join(" ") + .toLowerCase() + .includes(normalized) + ); +} + function groupedModels(models: ChatModel[]) { return models.reduce>((groups, model) => { const key = model.connectionLabel; @@ -100,24 +113,32 @@ export function ModelSelector({ const [{ data: roles }] = useAtom(modelRolesAtom); const updateRoles = useAtomValue(updateModelRolesMutationAtom); - const chatModels = useMemo(() => { - const normalized = search.trim().toLowerCase(); - const models = flattenChatModels([...globalConnections, ...connections]); - if (!normalized) return models; - return models.filter((model) => - [modelName(model), model.model_id, model.connectionLabel] - .join(" ") - .toLowerCase() - .includes(normalized) - ); - }, [globalConnections, connections, search]); + const allChatModels = useMemo( + () => flattenChatModels([...globalConnections, ...connections]), + [globalConnections, connections] + ); - const selected = chatModels.find((model) => model.id === roles?.chat_model_id); - const groups = groupedModels(chatModels); + const visibleChatModels = useMemo( + () => filterChatModels(allChatModels, search), + [allChatModels, search] + ); + const chatModelsById = useMemo( + () => new Map(allChatModels.map((model) => [model.id, model])), + [allChatModels] + ); + const selectedModelId = roles?.chat_model_id ?? AUTO_CHAT_MODEL_ID; + const selected = chatModelsById.get(selectedModelId); + const groups = useMemo(() => groupedModels(visibleChatModels), [visibleChatModels]); const loading = globalLoading || connectionsLoading; + function handleOpenChange(nextOpen: boolean) { + if (!nextOpen) setSearch(""); + setOpen(nextOpen); + } + function selectModel(modelId: number) { updateRoles.mutate({ chat_model_id: modelId }); + setSearch(""); setOpen(false); requestAnimationFrame(() => { onChatModelSelected?.(); @@ -160,7 +181,7 @@ export function ModelSelector({ {loading ? (
@@ -267,7 +288,7 @@ export function ModelSelector({ if (isMobile) { return ( - + {trigger} @@ -281,7 +302,7 @@ export function ModelSelector({ } return ( - + {trigger} {content}