feat(onboarding): simplify BYOK to provider + key, fetch models from key

This commit is contained in:
Prakhar Pandey 2026-06-19 01:12:53 +05:30
parent 2ddec07712
commit 92b95f659e
5 changed files with 136 additions and 166 deletions

View file

@ -25,7 +25,7 @@ import { RunEvent } from '@x/shared/dist/runs.js';
import { ServiceEvent } from '@x/shared/dist/service-events.js'; import { ServiceEvent } from '@x/shared/dist/service-events.js';
import container from '@x/core/dist/di/container.js'; import container from '@x/core/dist/di/container.js';
import { listOnboardingModels } from '@x/core/dist/models/models-dev.js'; import { listOnboardingModels } from '@x/core/dist/models/models-dev.js';
import { testModelConnection } from '@x/core/dist/models/models.js'; import { testModelConnection, listModelsForProvider } from '@x/core/dist/models/models.js';
import { isSignedIn } from '@x/core/dist/account/account.js'; import { isSignedIn } from '@x/core/dist/account/account.js';
import { listGatewayModels } from '@x/core/dist/models/gateway.js'; import { listGatewayModels } from '@x/core/dist/models/gateway.js';
import type { IModelConfigRepo } from '@x/core/dist/models/repo.js'; import type { IModelConfigRepo } from '@x/core/dist/models/repo.js';
@ -659,6 +659,15 @@ export function setupIpcHandlers() {
'models:test': async (_event, args) => { 'models:test': async (_event, args) => {
return await testModelConnection(args.provider, args.model); return await testModelConnection(args.provider, args.model);
}, },
'models:listForProvider': async (_event, args) => {
try {
const models = await listModelsForProvider(args.provider);
return { success: true, models };
} catch (err) {
const message = err instanceof Error ? err.message : 'Failed to list models';
return { success: false, error: message };
}
},
'models:saveConfig': async (_event, args) => { 'models:saveConfig': async (_event, args) => {
const repo = container.resolve<IModelConfigRepo>('modelConfigRepo'); const repo = container.resolve<IModelConfigRepo>('modelConfigRepo');
await repo.setConfig(args); await repo.setConfig(args);

View file

@ -2,13 +2,6 @@ import { Loader2, CheckCircle2, ArrowLeft, X, Lightbulb } from "lucide-react"
import { motion } from "motion/react" import { motion } from "motion/react"
import { Button } from "@/components/ui/button" import { Button } from "@/components/ui/button"
import { Input } from "@/components/ui/input" import { Input } from "@/components/ui/input"
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select"
import { cn } from "@/lib/utils" import { cn } from "@/lib/utils"
import { import {
OpenAIIcon, OpenAIIcon,
@ -40,16 +33,22 @@ const moreProviders: Array<{ id: LlmProviderFlavor; name: string; description: s
export function LlmSetupStep({ state }: LlmSetupStepProps) { export function LlmSetupStep({ state }: LlmSetupStepProps) {
const { const {
llmProvider, setLlmProvider, modelsCatalog, modelsLoading, modelsError, llmProvider, setLlmProvider, modelsLoading, modelsError,
activeConfig, testState, setTestState, showApiKey, activeConfig, testState, setTestState, showApiKey,
showBaseURL, isLocalProvider, canTest, showMoreProviders, setShowMoreProviders, showBaseURL, canTest, showMoreProviders, setShowMoreProviders,
updateProviderConfig, handleTestAndSaveLlmConfig, handleBack, updateProviderConfig, handleTestAndSaveLlmConfig, handleBack,
upsellDismissed, setUpsellDismissed, handleSwitchToRowboat, upsellDismissed, setUpsellDismissed, handleSwitchToRowboat,
} = state } = state
const isMoreProvider = moreProviders.some(p => p.id === llmProvider) const isMoreProvider = moreProviders.some(p => p.id === llmProvider)
const modelsForProvider = modelsCatalog[llmProvider] || [] // Hosted providers (openai/anthropic/google) get a default model, so we only
const showModelInput = isLocalProvider || modelsForProvider.length === 0 // ask for a model on providers that truly need one (local/custom/gateway),
// or as a fallback if no model is set yet.
// Hosted providers (openai/anthropic/google) fetch their models from the API
// key on test, so they never need a manual model field. Only local/custom/
// gateway providers, where the user must specify a model, show the input.
const hostedProviders: LlmProviderFlavor[] = ["openai", "anthropic", "google"]
const showModelInput = !hostedProviders.includes(llmProvider)
const renderProviderCard = (provider: typeof primaryProviders[0], index: number) => { const renderProviderCard = (provider: typeof primaryProviders[0], index: number) => {
const isSelected = llmProvider === provider.id const isSelected = llmProvider === provider.id
@ -87,7 +86,7 @@ export function LlmSetupStep({ state }: LlmSetupStepProps) {
<div className="flex flex-col flex-1"> <div className="flex flex-col flex-1">
{/* Title */} {/* Title */}
<h2 className="text-3xl font-bold tracking-tight text-center mb-2"> <h2 className="text-3xl font-bold tracking-tight text-center mb-2">
Choose your model Choose your provider
</h2> </h2>
<p className="text-base text-muted-foreground text-center mb-6"> <p className="text-base text-muted-foreground text-center mb-6">
Select a provider and configure your API key Select a provider and configure your API key
@ -145,153 +144,33 @@ export function LlmSetupStep({ state }: LlmSetupStepProps) {
{/* Separator */} {/* Separator */}
<div className="h-px bg-border my-4" /> <div className="h-px bg-border my-4" />
{/* Model configuration */} {/* Provider configuration */}
<div className="space-y-4"> <div className="space-y-4">
<h3 className="text-sm font-semibold">Model Configuration</h3> {/* Cloud providers get a default model auto-selected; only local/custom
providers (no catalog) need a model here. Users can pick any of the
<div className="grid grid-cols-1 sm:grid-cols-2 gap-4"> provider's models later in the chat view. */}
<div className="space-y-2 min-w-0"> {showModelInput && (
<div className="space-y-2">
<label className="text-xs font-medium text-muted-foreground"> <label className="text-xs font-medium text-muted-foreground">
Assistant Model Model
</label> </label>
{modelsLoading ? ( {modelsLoading ? (
<div className="flex items-center gap-2 text-sm text-muted-foreground"> <div className="flex items-center gap-2 text-sm text-muted-foreground">
<Loader2 className="size-4 animate-spin" /> <Loader2 className="size-4 animate-spin" />
Loading... Loading...
</div> </div>
) : showModelInput ? ( ) : (
<Input <Input
value={activeConfig.model} value={activeConfig.model}
onChange={(e) => updateProviderConfig(llmProvider, { model: e.target.value })} onChange={(e) => updateProviderConfig(llmProvider, { model: e.target.value })}
placeholder="Enter model" placeholder="Enter model"
/> />
) : (
<Select
value={activeConfig.model}
onValueChange={(value) => updateProviderConfig(llmProvider, { model: value })}
>
<SelectTrigger className="w-full truncate">
<SelectValue placeholder="Select a model" />
</SelectTrigger>
<SelectContent>
{modelsForProvider.map((model) => (
<SelectItem key={model.id} value={model.id}>
{model.name || model.id}
</SelectItem>
))}
</SelectContent>
</Select>
)} )}
{modelsError && ( {modelsError && (
<div className="text-xs text-destructive">{modelsError}</div> <div className="text-xs text-destructive">{modelsError}</div>
)} )}
</div> </div>
)}
<div className="space-y-2 min-w-0">
<label className="text-xs font-medium text-muted-foreground">
Knowledge Graph Model
</label>
{modelsLoading ? (
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<Loader2 className="size-4 animate-spin" />
Loading...
</div>
) : showModelInput ? (
<Input
value={activeConfig.knowledgeGraphModel}
onChange={(e) => updateProviderConfig(llmProvider, { knowledgeGraphModel: e.target.value })}
placeholder={activeConfig.model || "Enter model"}
/>
) : (
<Select
value={activeConfig.knowledgeGraphModel || "__same__"}
onValueChange={(value) => updateProviderConfig(llmProvider, { knowledgeGraphModel: value === "__same__" ? "" : value })}
>
<SelectTrigger className="w-full truncate">
<SelectValue placeholder="Select a model" />
</SelectTrigger>
<SelectContent>
<SelectItem value="__same__">Same as assistant</SelectItem>
{modelsForProvider.map((model) => (
<SelectItem key={model.id} value={model.id}>
{model.name || model.id}
</SelectItem>
))}
</SelectContent>
</Select>
)}
</div>
<div className="space-y-2 min-w-0">
<label className="text-xs font-medium text-muted-foreground">
Meeting Notes Model
</label>
{modelsLoading ? (
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<Loader2 className="size-4 animate-spin" />
Loading...
</div>
) : showModelInput ? (
<Input
value={activeConfig.meetingNotesModel}
onChange={(e) => updateProviderConfig(llmProvider, { meetingNotesModel: e.target.value })}
placeholder={activeConfig.model || "Enter model"}
/>
) : (
<Select
value={activeConfig.meetingNotesModel || "__same__"}
onValueChange={(value) => updateProviderConfig(llmProvider, { meetingNotesModel: value === "__same__" ? "" : value })}
>
<SelectTrigger className="w-full truncate">
<SelectValue placeholder="Select a model" />
</SelectTrigger>
<SelectContent>
<SelectItem value="__same__">Same as assistant</SelectItem>
{modelsForProvider.map((model) => (
<SelectItem key={model.id} value={model.id}>
{model.name || model.id}
</SelectItem>
))}
</SelectContent>
</Select>
)}
</div>
<div className="space-y-2 min-w-0">
<label className="text-xs font-medium text-muted-foreground">
Track Block Model
</label>
{modelsLoading ? (
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<Loader2 className="size-4 animate-spin" />
Loading...
</div>
) : showModelInput ? (
<Input
value={activeConfig.liveNoteAgentModel}
onChange={(e) => updateProviderConfig(llmProvider, { liveNoteAgentModel: e.target.value })}
placeholder={activeConfig.model || "Enter model"}
/>
) : (
<Select
value={activeConfig.liveNoteAgentModel || "__same__"}
onValueChange={(value) => updateProviderConfig(llmProvider, { liveNoteAgentModel: value === "__same__" ? "" : value })}
>
<SelectTrigger className="w-full truncate">
<SelectValue placeholder="Select a model" />
</SelectTrigger>
<SelectContent>
<SelectItem value="__same__">Same as assistant</SelectItem>
{modelsForProvider.map((model) => (
<SelectItem key={model.id} value={model.id}>
{model.name || model.id}
</SelectItem>
))}
</SelectContent>
</Select>
)}
</div>
</div>
{showApiKey && ( {showApiKey && (
<div className="space-y-2"> <div className="space-y-2">

View file

@ -98,7 +98,6 @@ export function useOnboardingState(open: boolean, onComplete: () => void) {
const showBaseURL = llmProvider === "ollama" || llmProvider === "openai-compatible" || llmProvider === "aigateway" const showBaseURL = llmProvider === "ollama" || llmProvider === "openai-compatible" || llmProvider === "aigateway"
const isLocalProvider = llmProvider === "ollama" || llmProvider === "openai-compatible" const isLocalProvider = llmProvider === "ollama" || llmProvider === "openai-compatible"
const canTest = const canTest =
activeConfig.model.trim().length > 0 &&
(!requiresApiKey || activeConfig.apiKey.trim().length > 0) && (!requiresApiKey || activeConfig.apiKey.trim().length > 0) &&
(!requiresBaseURL || activeConfig.baseURL.trim().length > 0) (!requiresBaseURL || activeConfig.baseURL.trim().length > 0)
@ -416,37 +415,45 @@ export function useOnboardingState(open: boolean, onComplete: () => void) {
try { try {
const apiKey = activeConfig.apiKey.trim() || undefined const apiKey = activeConfig.apiKey.trim() || undefined
const baseURL = activeConfig.baseURL.trim() || undefined const baseURL = activeConfig.baseURL.trim() || undefined
const model = activeConfig.model.trim() const provider = {
const knowledgeGraphModel = activeConfig.knowledgeGraphModel.trim() || undefined flavor: llmProvider,
const meetingNotesModel = activeConfig.meetingNotesModel.trim() || undefined apiKey,
const liveNoteAgentModel = activeConfig.liveNoteAgentModel.trim() || undefined baseURL,
const providerConfig = {
provider: {
flavor: llmProvider,
apiKey,
baseURL,
},
model,
knowledgeGraphModel,
meetingNotesModel,
liveNoteAgentModel,
} }
const result = await window.ipc.invoke("models:test", providerConfig)
if (result.success) { // Fetch the provider's models from the key — this both validates the
setTestState({ status: "success" }) // credentials and gives us the list to populate the chat picker.
await window.ipc.invoke("models:saveConfig", providerConfig) const result = await window.ipc.invoke("models:listForProvider", { provider })
window.dispatchEvent(new Event('models-config-changed')) if (!result.success) {
handleNext()
} else {
setTestState({ status: "error", error: result.error }) setTestState({ status: "error", error: result.error })
toast.error(result.error || "Connection test failed") toast.error(result.error || "Connection test failed")
return
} }
const models: string[] = result.models ?? []
const preferred = preferredDefaults[llmProvider]
const model =
(preferred && models.includes(preferred) && preferred) ||
models[0] ||
activeConfig.model.trim() ||
""
const providerConfig = {
provider,
model,
models,
}
setTestState({ status: "success" })
await window.ipc.invoke("models:saveConfig", providerConfig)
window.dispatchEvent(new Event('models-config-changed'))
handleNext()
} catch (error) { } catch (error) {
console.error("Connection test failed:", error) console.error("Connection test failed:", error)
setTestState({ status: "error", error: "Connection test failed" }) setTestState({ status: "error", error: "Connection test failed" })
toast.error("Connection test failed") toast.error("Connection test failed")
} }
}, [activeConfig.apiKey, activeConfig.baseURL, activeConfig.model, activeConfig.knowledgeGraphModel, activeConfig.meetingNotesModel, activeConfig.liveNoteAgentModel, canTest, llmProvider, handleNext]) }, [activeConfig.apiKey, activeConfig.baseURL, activeConfig.model, canTest, llmProvider, handleNext])
// Check connection status for all providers // Check connection status for all providers
const refreshAllStatuses = useCallback(async () => { const refreshAllStatuses = useCallback(async () => {

View file

@ -96,3 +96,68 @@ export async function testModelConnection(
clearTimeout(timeout); clearTimeout(timeout);
} }
} }
export async function listModelsForProvider(
providerConfig: z.infer<typeof Provider>,
timeoutMs = 8000,
): Promise<string[]> {
const { flavor, apiKey, baseURL } = providerConfig;
const controller = new AbortController();
const timeout = setTimeout(() => controller.abort(), timeoutMs);
try {
let url = "";
const headers: Record<string, string> = {};
switch (flavor) {
case "openai":
url = "https://api.openai.com/v1/models";
headers["Authorization"] = `Bearer ${apiKey}`;
break;
case "anthropic":
url = "https://api.anthropic.com/v1/models";
headers["x-api-key"] = apiKey ?? "";
headers["anthropic-version"] = "2023-06-01";
break;
case "google":
url = `https://generativelanguage.googleapis.com/v1beta/models?key=${apiKey ?? ""}`;
break;
case "openrouter":
url = "https://openrouter.ai/api/v1/models";
if (apiKey) headers["Authorization"] = `Bearer ${apiKey}`;
break;
case "ollama":
url = `${(baseURL ?? "http://localhost:11434").replace(/\/$/, "")}/api/tags`;
break;
case "openai-compatible":
case "aigateway":
url = `${(baseURL ?? "").replace(/\/$/, "")}/models`;
if (apiKey) headers["Authorization"] = `Bearer ${apiKey}`;
break;
default:
throw new Error(`Unsupported provider flavor: ${flavor}`);
}
const res = await fetch(url, { headers, signal: controller.signal });
if (!res.ok) {
const body = await res.text().catch(() => "");
throw new Error(`Failed to list models (${res.status}): ${body.slice(0, 200)}`);
}
const data = await res.json();
// Normalize each provider's response shape into a flat list of model id strings.
let ids: string[] = [];
if (flavor === "google") {
// { models: [{ name: "models/gemini-..." }] }
ids = (data.models ?? []).map((m: { name: string }) => m.name.replace(/^models\//, ""));
} else if (flavor === "ollama") {
// { models: [{ name: "llama3:latest" }] }
ids = (data.models ?? []).map((m: { name: string }) => m.name);
} else {
// OpenAI-shaped: { data: [{ id: "..." }] }
ids = (data.data ?? []).map((m: { id: string }) => m.id);
}
return ids.filter((id: string) => typeof id === "string" && id.length > 0);
} finally {
clearTimeout(timeout);
}
}

View file

@ -2,7 +2,7 @@ import { z } from 'zod';
import { RelPath, Encoding, Stat, DirEntry, ReaddirOptions, ReadFileResult, WorkspaceChangeEvent, WriteFileOptions, WriteFileResult, RemoveOptions } from './workspace.js'; import { RelPath, Encoding, Stat, DirEntry, ReaddirOptions, ReadFileResult, WorkspaceChangeEvent, WriteFileOptions, WriteFileResult, RemoveOptions } from './workspace.js';
import { ListToolsResponse } from './mcp.js'; import { ListToolsResponse } from './mcp.js';
import { AskHumanResponsePayload, CreateRunOptions, Run, ListRunsResponse, ToolPermissionAuthorizePayload } from './runs.js'; import { AskHumanResponsePayload, CreateRunOptions, Run, ListRunsResponse, ToolPermissionAuthorizePayload } from './runs.js';
import { LlmModelConfig } from './models.js'; import { LlmModelConfig, LlmProvider } from './models.js';
import { AgentScheduleConfig, AgentScheduleEntry } from './agent-schedule.js'; import { AgentScheduleConfig, AgentScheduleEntry } from './agent-schedule.js';
import { AgentScheduleState } from './agent-schedule-state.js'; import { AgentScheduleState } from './agent-schedule-state.js';
import { ServiceEvent } from './service-events.js'; import { ServiceEvent } from './service-events.js';
@ -361,6 +361,16 @@ const ipcSchemas = {
error: z.string().optional(), error: z.string().optional(),
}), }),
}, },
'models:listForProvider': {
req: z.object({
provider: LlmProvider,
}),
res: z.object({
success: z.boolean(),
models: z.array(z.string()).optional(),
error: z.string().optional(),
}),
},
'models:saveConfig': { 'models:saveConfig': {
req: LlmModelConfig, req: LlmModelConfig,
res: z.object({ res: z.object({