diff --git a/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts new file mode 100644 index 000000000..7d58a402c --- /dev/null +++ b/surfsense_web/atoms/model-connections/model-connections-mutation.atoms.ts @@ -0,0 +1,129 @@ +import { atomWithMutation } from "jotai-tanstack-query"; +import { toast } from "sonner"; +import type { + ConnectionCreateRequest, + ConnectionUpdateRequest, + ModelRoles, + ModelUpdateRequest, +} from "@/contracts/types/model-connections.types"; +import { modelConnectionsApiService } from "@/lib/apis/model-connections-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { queryClient } from "@/lib/query-client/client"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +function invalidateModelConnections(searchSpaceId: number) { + queryClient.invalidateQueries({ + queryKey: cacheKeys.modelConnections.all(searchSpaceId), + }); + queryClient.invalidateQueries({ + queryKey: cacheKeys.modelConnections.roles(searchSpaceId), + }); +} + +export const createModelConnectionMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["model-connections", "create"], + mutationFn: (request: ConnectionCreateRequest) => + modelConnectionsApiService.createConnection(request), + onSuccess: () => { + toast.success("Connection created"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to create connection"), + }; +}); + +export const updateModelConnectionMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["model-connections", "update"], + mutationFn: ({ id, data }: { id: number; data: ConnectionUpdateRequest }) => + modelConnectionsApiService.updateConnection(id, data), + onSuccess: () => { + toast.success("Connection updated"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to update connection"), + }; +}); + +export const deleteModelConnectionMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["model-connections", "delete"], + mutationFn: (id: number) => modelConnectionsApiService.deleteConnection(id), + onSuccess: () => { + toast.success("Connection deleted"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to delete connection"), + }; +}); + +export const verifyModelConnectionMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["model-connections", "verify"], + mutationFn: (id: number) => modelConnectionsApiService.verifyConnection(id), + onSuccess: (result) => { + if (result.ok) toast.success("Connection verified"); + else toast.error(result.message || "Connection failed"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to verify connection"), + }; +}); + +export const discoverConnectionModelsMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["model-connections", "discover"], + mutationFn: (id: number) => modelConnectionsApiService.discoverModels(id), + onSuccess: () => { + toast.success("Models discovered"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to discover models"), + }; +}); + +export const updateModelMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["models", "update"], + mutationFn: ({ id, data }: { id: number; data: ModelUpdateRequest }) => + modelConnectionsApiService.updateModel(id, data), + onSuccess: () => invalidateModelConnections(searchSpaceId), + onError: (error: Error) => toast.error(error.message || "Failed to update model"), + }; +}); + +export const testModelMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["models", "test"], + mutationFn: (id: number) => modelConnectionsApiService.testModel(id), + onSuccess: (result) => { + if (result.ok) toast.success("Model test succeeded"); + else toast.error(result.message || "Model test failed"); + invalidateModelConnections(searchSpaceId); + }, + onError: (error: Error) => toast.error(error.message || "Failed to test model"), + }; +}); + +export const updateModelRolesMutationAtom = atomWithMutation((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + mutationKey: ["model-roles", "update"], + mutationFn: (roles: ModelRoles) => + modelConnectionsApiService.updateModelRoles(searchSpaceId, roles), + onSuccess: () => { + queryClient.invalidateQueries({ + queryKey: cacheKeys.modelConnections.roles(searchSpaceId), + }); + }, + onError: (error: Error) => toast.error(error.message || "Failed to update model roles"), + }; +}); diff --git a/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts b/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts new file mode 100644 index 000000000..617ffe124 --- /dev/null +++ b/surfsense_web/atoms/model-connections/model-connections-query.atoms.ts @@ -0,0 +1,32 @@ +import { atomWithQuery } from "jotai-tanstack-query"; +import { modelConnectionsApiService } from "@/lib/apis/model-connections-api.service"; +import { getBearerToken } from "@/lib/auth-utils"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; +import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms"; + +export const globalModelConnectionsAtom = atomWithQuery(() => ({ + queryKey: cacheKeys.modelConnections.global(), + enabled: !!getBearerToken(), + staleTime: 10 * 60 * 1000, + queryFn: () => modelConnectionsApiService.getGlobalConnections(), +})); + +export const modelConnectionsAtom = atomWithQuery((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + queryKey: cacheKeys.modelConnections.all(searchSpaceId), + enabled: !!searchSpaceId, + staleTime: 5 * 60 * 1000, + queryFn: () => modelConnectionsApiService.getConnections(searchSpaceId), + }; +}); + +export const modelRolesAtom = atomWithQuery((get) => { + const searchSpaceId = Number(get(activeSearchSpaceIdAtom)); + return { + queryKey: cacheKeys.modelConnections.roles(searchSpaceId), + enabled: !!searchSpaceId, + staleTime: 5 * 60 * 1000, + queryFn: () => modelConnectionsApiService.getModelRoles(searchSpaceId), + }; +}); diff --git a/surfsense_web/contracts/types/model-connections.types.ts b/surfsense_web/contracts/types/model-connections.types.ts new file mode 100644 index 000000000..14f93c61a --- /dev/null +++ b/surfsense_web/contracts/types/model-connections.types.ts @@ -0,0 +1,98 @@ +import { z } from "zod"; + +export const connectionProtocolEnum = z.enum(["OLLAMA", "OPENAI_COMPATIBLE", "NATIVE"]); +export const connectionScopeEnum = z.enum(["GLOBAL", "SEARCH_SPACE", "USER"]); +export const modelSourceEnum = z.enum(["DISCOVERED", "MANUAL"]); + +export const modelCapabilities = z.object({ + chat: z.boolean().optional(), + vision: z.boolean().optional(), + image_gen: z.boolean().optional(), + embedding: z.boolean().optional(), + tools: z.boolean().optional(), +}); + +export const modelRead = z.object({ + id: z.number(), + connection_id: z.number(), + model_id: z.string(), + display_name: z.string().nullable().optional(), + source: z.union([modelSourceEnum, z.string()]), + capabilities: z.record(z.string(), z.any()).default({}), + capabilities_declared: z.record(z.string(), z.any()).default({}), + capabilities_verified: z.record(z.string(), z.any()).default({}), + capabilities_override: z.record(z.string(), z.any()).default({}), + embedding_dimension: z.number().nullable().optional(), + enabled: z.boolean(), + billing_tier: z.string().nullable().optional(), + catalog: z.record(z.string(), z.any()).default({}), + created_at: z.string().nullable().optional(), +}); + +export const connectionRead = z.object({ + id: z.number(), + protocol: z.union([connectionProtocolEnum, z.string()]), + native_provider: z.string().nullable().optional(), + base_url: z.string().nullable().optional(), + extra: z.record(z.string(), z.any()).default({}), + scope: z.union([connectionScopeEnum, z.string()]), + search_space_id: z.number().nullable().optional(), + user_id: z.string().nullable().optional(), + enabled: z.boolean(), + has_api_key: z.boolean(), + last_verified_at: z.string().nullable().optional(), + last_status: z.string().nullable().optional(), + last_error: z.string().nullable().optional(), + models: z.array(modelRead).default([]), + created_at: z.string().nullable().optional(), +}); + +export const connectionCreateRequest = z.object({ + protocol: connectionProtocolEnum, + native_provider: z.string().nullable().optional(), + base_url: z.string().nullable().optional(), + api_key: z.string().nullable().optional(), + extra: z.record(z.string(), z.any()).default({}), + scope: connectionScopeEnum.default("SEARCH_SPACE"), + search_space_id: z.number().nullable().optional(), + enabled: z.boolean().default(true), +}); + +export const connectionUpdateRequest = z.object({ + native_provider: z.string().nullable().optional(), + base_url: z.string().nullable().optional(), + api_key: z.string().nullable().optional(), + extra: z.record(z.string(), z.any()).optional(), + enabled: z.boolean().optional(), +}); + +export const modelUpdateRequest = z.object({ + display_name: z.string().nullable().optional(), + enabled: z.boolean().optional(), + capabilities_override: z.record(z.string(), z.any()).optional(), +}); + +export const verifyConnectionResponse = z.object({ + status: z.string(), + ok: z.boolean(), + message: z.string().default(""), +}); + +export const modelRoles = z.object({ + chat_model_id: z.number().nullable().optional(), + vision_model_id: z.number().nullable().optional(), + image_gen_model_id: z.number().nullable().optional(), +}); + +export const connectionListResponse = z.array(connectionRead); +export const modelListResponse = z.array(modelRead); + +export type ConnectionProtocol = z.infer; +export type ConnectionScope = z.infer; +export type ModelRead = z.infer; +export type ConnectionRead = z.infer; +export type ConnectionCreateRequest = z.infer; +export type ConnectionUpdateRequest = z.infer; +export type ModelUpdateRequest = z.infer; +export type ModelRoles = z.infer; +export type VerifyConnectionResponse = z.infer; diff --git a/surfsense_web/lib/apis/model-connections-api.service.ts b/surfsense_web/lib/apis/model-connections-api.service.ts new file mode 100644 index 000000000..ca92ad11b --- /dev/null +++ b/surfsense_web/lib/apis/model-connections-api.service.ts @@ -0,0 +1,88 @@ +import { + type ConnectionCreateRequest, + type ConnectionUpdateRequest, + connectionCreateRequest, + connectionListResponse, + connectionRead, + connectionUpdateRequest, + type ModelRoles, + type ModelUpdateRequest, + modelListResponse, + modelRead, + modelRoles, + modelUpdateRequest, + verifyConnectionResponse, +} from "@/contracts/types/model-connections.types"; +import { ValidationError } from "../error"; +import { baseApiService } from "./base-api.service"; + +class ModelConnectionsApiService { + getGlobalConnections = async () => { + return baseApiService.get(`/api/v1/global-model-connections`, connectionListResponse); + }; + + getConnections = async (searchSpaceId: number) => { + return baseApiService.get( + `/api/v1/model-connections?search_space_id=${searchSpaceId}`, + connectionListResponse + ); + }; + + createConnection = async (request: ConnectionCreateRequest) => { + const parsed = connectionCreateRequest.safeParse(request); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); + } + return baseApiService.post(`/api/v1/model-connections`, connectionRead, { + body: parsed.data, + }); + }; + + updateConnection = async (id: number, request: ConnectionUpdateRequest) => { + const parsed = connectionUpdateRequest.safeParse(request); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); + } + return baseApiService.put(`/api/v1/model-connections/${id}`, connectionRead, { + body: parsed.data, + }); + }; + + deleteConnection = async (id: number) => { + return baseApiService.delete(`/api/v1/model-connections/${id}`, undefined); + }; + + verifyConnection = async (id: number) => { + return baseApiService.post(`/api/v1/model-connections/${id}/verify`, verifyConnectionResponse); + }; + + discoverModels = async (id: number) => { + return baseApiService.post(`/api/v1/model-connections/${id}/discover`, modelListResponse); + }; + + updateModel = async (id: number, request: ModelUpdateRequest) => { + const parsed = modelUpdateRequest.safeParse(request); + if (!parsed.success) { + throw new ValidationError(parsed.error.issues.map((issue) => issue.message).join(", ")); + } + return baseApiService.put(`/api/v1/models/${id}`, modelRead, { + body: parsed.data, + }); + }; + + testModel = async (id: number) => { + return baseApiService.post(`/api/v1/models/${id}/test`, verifyConnectionResponse); + }; + + getModelRoles = async (searchSpaceId: number) => { + return baseApiService.get(`/api/v1/search-spaces/${searchSpaceId}/model-roles`, modelRoles); + }; + + updateModelRoles = async (searchSpaceId: number, roles: ModelRoles) => { + return baseApiService.put(`/api/v1/search-spaces/${searchSpaceId}/model-roles`, modelRoles, { + body: roles, + }); + }; +} + +export const modelConnectionsApiService = new ModelConnectionsApiService(); diff --git a/surfsense_web/lib/query-client/cache-keys.ts b/surfsense_web/lib/query-client/cache-keys.ts index 6f8885d7e..558a73f95 100644 --- a/surfsense_web/lib/query-client/cache-keys.ts +++ b/surfsense_web/lib/query-client/cache-keys.ts @@ -44,6 +44,11 @@ export const cacheKeys = { global: () => ["new-llm-configs", "global"] as const, modelList: () => ["models", "catalogue"] as const, }, + modelConnections: { + all: (searchSpaceId: number) => ["model-connections", searchSpaceId] as const, + global: () => ["model-connections", "global"] as const, + roles: (searchSpaceId: number) => ["model-roles", searchSpaceId] as const, + }, imageGenConfigs: { all: (searchSpaceId: number) => ["image-gen-configs", searchSpaceId] as const, byId: (configId: number) => ["image-gen-configs", "detail", configId] as const,