diff --git a/CLAUDE.md b/CLAUDE.md index c59ed5b2..db51cb63 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -104,6 +104,11 @@ pnpm uses symlinks for workspace packages. Electron Forge's dependency walker ca ## Common Tasks +### LLM configuration (single provider) +- Config file: `~/.rowboat/config/models.json` +- Schema: `{ provider: { flavor, apiKey?, baseURL?, headers? }, model: string }` +- Models catalog cache: `~/.rowboat/config/models.dev.json` (OpenAI/Anthropic/Google only) + ### Add a new shared type 1. Edit `apps/x/packages/shared/src/` 2. Run `cd apps/x && npm run deps` to rebuild @@ -133,7 +138,7 @@ cd apps/x && npm run deps && npm run lint | UI | React 19, Vite 7 | | Styling | TailwindCSS, Radix UI | | State | React hooks | -| AI | Vercel AI SDK, Anthropic/OpenAI/Google providers | +| AI | Vercel AI SDK, OpenAI/Anthropic/Google/OpenRouter providers, Vercel AI Gateway, Ollama, models.dev catalog | | IPC | Electron contextBridge | | Build | TypeScript 5.9, esbuild, Electron Forge | diff --git a/apps/x/apps/main/src/ipc.ts b/apps/x/apps/main/src/ipc.ts index 87997713..5a7a7bd9 100644 --- a/apps/x/apps/main/src/ipc.ts +++ b/apps/x/apps/main/src/ipc.ts @@ -17,6 +17,9 @@ import fs from 'node:fs/promises'; import z from 'zod'; import { RunEvent } from '@x/shared/dist/runs.js'; import container from '@x/core/dist/di/container.js'; +import { listOnboardingModels } from '@x/core/dist/models/models-dev.js'; +import { testModelConnection } from '@x/core/dist/models/models.js'; +import type { IModelConfigRepo } from '@x/core/dist/models/repo.js'; import { IGranolaConfigRepo } from '@x/core/dist/knowledge/granola/repo.js'; import { triggerSync as triggerGranolaSync } from '@x/core/dist/knowledge/granola/sync.js'; import { isOnboardingComplete, markOnboardingComplete } from '@x/core/dist/config/note_creation_config.js'; @@ -305,6 +308,17 @@ export function setupIpcHandlers() { 'runs:list': async (_event, args) => { return runsCore.listRuns(args.cursor); }, + 'models:list': async () => { + return await listOnboardingModels(); + }, + 'models:test': async (_event, args) => { + return await testModelConnection(args.provider, args.model); + }, + 'models:saveConfig': async (_event, args) => { + const repo = container.resolve('modelConfigRepo'); + await repo.setConfig(args); + return { success: true }; + }, 'oauth:connect': async (_event, args) => { return await connectProvider(args.provider, args.clientId); }, @@ -371,4 +385,4 @@ export function setupIpcHandlers() { return composioHandler.executeAction(args.actionSlug, args.toolkitSlug, args.input); }, }); -} \ No newline at end of file +} diff --git a/apps/x/apps/renderer/src/components/onboarding-modal.tsx b/apps/x/apps/renderer/src/components/onboarding-modal.tsx index a1621d84..1f664f10 100644 --- a/apps/x/apps/renderer/src/components/onboarding-modal.tsx +++ b/apps/x/apps/renderer/src/components/onboarding-modal.tsx @@ -13,6 +13,14 @@ import { } from "@/components/ui/dialog" import { Button } from "@/components/ui/button" import { Switch } from "@/components/ui/switch" +import { Input } from "@/components/ui/input" +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select" import { cn } from "@/lib/utils" import { ComposioApiKeyModal } from "@/components/composio-api-key-modal" import { GoogleClientIdModal } from "@/components/google-client-id-modal" @@ -30,11 +38,38 @@ interface OnboardingModalProps { onComplete: () => void } -type Step = 0 | 1 | 2 +type Step = 0 | 1 | 2 | 3 + +type LlmProviderFlavor = "openai" | "anthropic" | "google" | "openrouter" | "aigateway" | "ollama" | "openai-compatible" + +interface LlmModelOption { + id: string + name?: string + release_date?: string +} export function OnboardingModal({ open, onComplete }: OnboardingModalProps) { const [currentStep, setCurrentStep] = useState(0) + // LLM setup state + const [llmProvider, setLlmProvider] = useState("openai") + const [modelsCatalog, setModelsCatalog] = useState>({}) + const [modelsLoading, setModelsLoading] = useState(false) + const [modelsError, setModelsError] = useState(null) + const [providerConfigs, setProviderConfigs] = useState>({ + openai: { apiKey: "", baseURL: "", model: "" }, + anthropic: { apiKey: "", baseURL: "", model: "" }, + google: { apiKey: "", baseURL: "", model: "" }, + openrouter: { apiKey: "", baseURL: "", model: "" }, + aigateway: { apiKey: "", baseURL: "", model: "" }, + ollama: { apiKey: "", baseURL: "http://localhost:11434", model: "" }, + "openai-compatible": { apiKey: "", baseURL: "http://localhost:1234/v1", model: "" }, + }) + const [testState, setTestState] = useState<{ status: "idle" | "testing" | "success" | "error"; error?: string }>({ + status: "idle", + }) + const [savingLlmConfig, setSavingLlmConfig] = useState(false) + // OAuth provider states const [providers, setProviders] = useState([]) const [providersLoading, setProvidersLoading] = useState(true) @@ -51,6 +86,27 @@ export function OnboardingModal({ open, onComplete }: OnboardingModalProps) { const [slackLoading, setSlackLoading] = useState(true) const [slackConnecting, setSlackConnecting] = useState(false) + const updateProviderConfig = useCallback( + (provider: LlmProviderFlavor, updates: Partial<{ apiKey: string; baseURL: string; model: string }>) => { + setProviderConfigs(prev => ({ + ...prev, + [provider]: { ...prev[provider], ...updates }, + })) + setTestState({ status: "idle" }) + }, + [] + ) + + const activeConfig = providerConfigs[llmProvider] + const requiresApiKey = llmProvider === "openai" || llmProvider === "anthropic" || llmProvider === "google" || llmProvider === "openrouter" || llmProvider === "aigateway" + const requiresBaseURL = llmProvider === "ollama" || llmProvider === "openai-compatible" + const showBaseURL = llmProvider === "ollama" || llmProvider === "openai-compatible" || llmProvider === "aigateway" + const isLocalProvider = llmProvider === "ollama" || llmProvider === "openai-compatible" + const canTest = + activeConfig.model.trim().length > 0 && + (!requiresApiKey || activeConfig.apiKey.trim().length > 0) && + (!requiresBaseURL || activeConfig.baseURL.trim().length > 0) + // Track connected providers for the completion step const connectedProviders = Object.entries(providerStates) .filter(([, state]) => state.isConnected) @@ -75,6 +131,48 @@ export function OnboardingModal({ open, onComplete }: OnboardingModalProps) { loadProviders() }, [open]) + // Load LLM models catalog on open + useEffect(() => { + if (!open) return + + async function loadModels() { + try { + setModelsLoading(true) + setModelsError(null) + const result = await window.ipc.invoke("models:list", null) + const catalog: Record = {} + for (const provider of result.providers || []) { + catalog[provider.id] = provider.models || [] + } + setModelsCatalog(catalog) + } catch (error) { + console.error("Failed to load models catalog:", error) + setModelsError("Failed to load models list") + setModelsCatalog({}) + } finally { + setModelsLoading(false) + } + } + + loadModels() + }, [open]) + + // Initialize default models from catalog + useEffect(() => { + if (Object.keys(modelsCatalog).length === 0) return + setProviderConfigs(prev => { + const next = { ...prev } + const cloudProviders: LlmProviderFlavor[] = ["openai", "anthropic", "google"] + for (const provider of cloudProviders) { + const models = modelsCatalog[provider] + if (models?.length && !next[provider].model) { + next[provider] = { ...next[provider], model: models[0]?.id || "" } + } + } + return next + }) + }, [modelsCatalog]) + // Load Granola config const refreshGranolaConfig = useCallback(async () => { try { @@ -160,6 +258,69 @@ export function OnboardingModal({ open, onComplete }: OnboardingModalProps) { } }, [startSlackConnect]) + const handleNext = () => { + if (currentStep < 3) { + setCurrentStep((prev) => (prev + 1) as Step) + } + } + + const handleComplete = () => { + onComplete() + } + + const handleTestConnection = useCallback(async () => { + if (!canTest) return + setTestState({ status: "testing" }) + try { + const apiKey = activeConfig.apiKey.trim() || undefined + const baseURL = activeConfig.baseURL.trim() || undefined + const model = activeConfig.model.trim() + const result = await window.ipc.invoke("models:test", { + provider: { + flavor: llmProvider, + apiKey, + baseURL, + }, + model, + }) + if (result.success) { + setTestState({ status: "success" }) + toast.success("Connection successful") + } else { + setTestState({ status: "error", error: result.error }) + toast.error(result.error || "Connection test failed") + } + } catch (error) { + console.error("Connection test failed:", error) + setTestState({ status: "error", error: "Connection test failed" }) + toast.error("Connection test failed") + } + }, [activeConfig.apiKey, activeConfig.baseURL, activeConfig.model, canTest, llmProvider]) + + const handleSaveLlmConfig = useCallback(async () => { + if (testState.status !== "success") return + setSavingLlmConfig(true) + try { + const apiKey = activeConfig.apiKey.trim() || undefined + const baseURL = activeConfig.baseURL.trim() || undefined + const model = activeConfig.model.trim() + await window.ipc.invoke("models:saveConfig", { + provider: { + flavor: llmProvider, + apiKey, + baseURL, + }, + model, + }) + setSavingLlmConfig(false) + handleNext() + } catch (error) { + console.error("Failed to save LLM config:", error) + toast.error("Failed to save LLM settings") + setSavingLlmConfig(false) + } + }, [activeConfig.apiKey, activeConfig.baseURL, activeConfig.model, handleNext, llmProvider, testState.status]) + // Check connection status for all providers const refreshAllStatuses = useCallback(async () => { // Refresh Granola @@ -295,20 +456,10 @@ export function OnboardingModal({ open, onComplete }: OnboardingModalProps) { startConnect('google', clientId) }, [startConnect]) - const handleNext = () => { - if (currentStep < 2) { - setCurrentStep((prev) => (prev + 1) as Step) - } - } - - const handleComplete = () => { - onComplete() - } - // Step indicator component const StepIndicator = () => (
- {[0, 1, 2].map((step) => ( + {[0, 1, 2, 3].map((step) => (
) - // Step 1: Connect Accounts + // Step 1: LLM Setup + const LlmSetupStep = () => { + const providerOptions: Array<{ id: LlmProviderFlavor; name: string; description: string }> = [ + { id: "openai", name: "OpenAI", description: "Use your OpenAI API key" }, + { id: "anthropic", name: "Anthropic", description: "Use your Anthropic API key" }, + { id: "google", name: "Google", description: "Use your Google AI Studio key" }, + { id: "openrouter", name: "OpenRouter", description: "Access multiple models with one key" }, + { id: "aigateway", name: "AI Gateway (Vercel)", description: "Use Vercel's AI Gateway" }, + { id: "ollama", name: "Ollama (Local)", description: "Run a local model via Ollama" }, + { id: "openai-compatible", name: "OpenAI-Compatible", description: "Local or hosted OpenAI-compatible API" }, + ] + + const modelsForProvider = modelsCatalog[llmProvider] || [] + const showModelInput = isLocalProvider || modelsForProvider.length === 0 + + return ( +
+ + Choose your model + + Select your provider and model to power Rowboat’s AI. + + + +
+
+ Provider +
+ {providerOptions.map((provider) => ( + + ))} +
+
+ +
+ Model + {modelsLoading ? ( +
+ + Loading models... +
+ ) : showModelInput ? ( + updateProviderConfig(llmProvider, { model: e.target.value })} + placeholder="Enter model ID" + /> + ) : ( + + )} + {modelsError && ( +
{modelsError}
+ )} +
+ + {requiresApiKey && ( +
+ API Key + updateProviderConfig(llmProvider, { apiKey: e.target.value })} + placeholder="Paste your API key" + /> +
+ )} + + {showBaseURL && ( +
+ Base URL + updateProviderConfig(llmProvider, { baseURL: e.target.value })} + placeholder={ + llmProvider === "ollama" + ? "http://localhost:11434" + : llmProvider === "openai-compatible" + ? "http://localhost:1234/v1" + : "https://ai-gateway.vercel.sh/v1" + } + /> +
+ )} +
+ +
+ + {testState.status === "success" && ( + Connected + )} + {testState.status === "error" && ( + + {testState.error || "Test failed"} + + )} +
+ +
+ +
+
+ ) + } + + // Step 2: Connect Accounts const AccountConnectionStep = () => (
@@ -534,7 +834,7 @@ export function OnboardingModal({ open, onComplete }: OnboardingModalProps) {
) - // Step 2: Completion + // Step 3: Completion const CompletionStep = () => { const hasConnections = connectedProviders.length > 0 || granolaEnabled || slackConnected @@ -618,8 +918,9 @@ export function OnboardingModal({ open, onComplete }: OnboardingModalProps) { > {currentStep === 0 && } - {currentStep === 1 && } - {currentStep === 2 && } + {currentStep === 1 && } + {currentStep === 2 && } + {currentStep === 3 && } diff --git a/apps/x/packages/core/src/agents/runtime.ts b/apps/x/packages/core/src/agents/runtime.ts index 9a6b7dcb..0246cc2f 100644 --- a/apps/x/packages/core/src/agents/runtime.ts +++ b/apps/x/packages/core/src/agents/runtime.ts @@ -15,7 +15,7 @@ import { CopilotAgent } from "../application/assistant/agent.js"; import { isBlocked } from "../application/lib/command-executor.js"; import container from "../di/container.js"; import { IModelConfigRepo } from "../models/repo.js"; -import { getProvider } from "../models/models.js"; +import { createProvider } from "../models/models.js"; import { IAgentsRepo } from "./repo.js"; import { IMonotonicallyIncreasingIdGenerator } from "../application/lib/id-gen.js"; import { IBus } from "../application/lib/bus.js"; @@ -623,8 +623,8 @@ export async function* streamAgent({ const tools = await buildTools(agent); // set up provider + model - const provider = await getProvider(agent.provider); - const model = provider.languageModel(agent.model || modelConfig.defaults.model); + const provider = createProvider(modelConfig.provider); + const model = provider.languageModel(modelConfig.model); let loopCounter = 0; while (true) { diff --git a/apps/x/packages/core/src/models/models-dev.ts b/apps/x/packages/core/src/models/models-dev.ts new file mode 100644 index 00000000..6fecb694 --- /dev/null +++ b/apps/x/packages/core/src/models/models-dev.ts @@ -0,0 +1,174 @@ +import fs from "node:fs/promises"; +import path from "node:path"; +import z from "zod"; +import { WorkDir } from "../config/config.js"; + +const CACHE_PATH = path.join(WorkDir, "config", "models.dev.json"); +const CACHE_TTL_MS = 24 * 60 * 60 * 1000; + +const ModelsDevModel = z.object({ + id: z.string().optional(), + name: z.string().optional(), + release_date: z.string().optional(), + tool_call: z.boolean().optional(), + experimental: z.boolean().optional(), + status: z.enum(["alpha", "beta", "deprecated"]).optional(), +}).passthrough(); + +const ModelsDevProvider = z.object({ + id: z.string().optional(), + name: z.string(), + models: z.record(z.string(), ModelsDevModel), +}).passthrough(); + +const ModelsDevResponse = z.record(z.string(), ModelsDevProvider); + +type ProviderSummary = { + id: string; + name: string; + models: Array<{ + id: string; + name?: string; + release_date?: string; + }>; +}; + +type CacheFile = { + fetchedAt: string; + data: unknown; +}; + +async function readCache(): Promise { + try { + const raw = await fs.readFile(CACHE_PATH, "utf8"); + return JSON.parse(raw) as CacheFile; + } catch { + return null; + } +} + +async function writeCache(data: unknown): Promise { + const payload: CacheFile = { + fetchedAt: new Date().toISOString(), + data, + }; + await fs.writeFile(CACHE_PATH, JSON.stringify(payload, null, 2)); +} + +async function fetchModelsDev(): Promise { + const response = await fetch("https://models.dev/api.json", { + headers: { "User-Agent": "Rowboat" }, + }); + if (!response.ok) { + throw new Error(`models.dev fetch failed: ${response.status}`); + } + return response.json(); +} + +function isCacheFresh(fetchedAt: string): boolean { + const age = Date.now() - new Date(fetchedAt).getTime(); + return age < CACHE_TTL_MS; +} + +async function getModelsDevData(): Promise<{ data: z.infer; fetchedAt?: string }> { + const cached = await readCache(); + if (cached?.fetchedAt && isCacheFresh(cached.fetchedAt)) { + const parsed = ModelsDevResponse.safeParse(cached.data); + if (parsed.success) { + return { data: parsed.data, fetchedAt: cached.fetchedAt }; + } + } + + try { + const fresh = await fetchModelsDev(); + const parsed = ModelsDevResponse.parse(fresh); + await writeCache(parsed); + return { data: parsed, fetchedAt: new Date().toISOString() }; + } catch (error) { + if (cached) { + const parsed = ModelsDevResponse.safeParse(cached.data); + if (parsed.success) { + return { data: parsed.data, fetchedAt: cached.fetchedAt }; + } + } + throw error; + } +} + +function scoreProvider(flavor: string, id: string, name: string): number { + const normalizedId = id.toLowerCase(); + const normalizedName = name.toLowerCase(); + let score = 0; + if (normalizedId === flavor) score += 100; + if (normalizedName.includes(flavor)) score += 20; + if (flavor === "google") { + if (normalizedName.includes("gemini")) score += 10; + if (normalizedName.includes("vertex")) score -= 5; + } + return score; +} + +function pickProvider( + data: z.infer, + flavor: "openai" | "anthropic" | "google", +): z.infer | null { + if (data[flavor]) return data[flavor]; + let best: { score: number; provider: z.infer } | null = null; + for (const [id, provider] of Object.entries(data)) { + const s = scoreProvider(flavor, id, provider.name); + if (s <= 0) continue; + if (!best || s > best.score) { + best = { score: s, provider }; + } + } + return best?.provider ?? null; +} + +function isStableModel(model: z.infer): boolean { + if (model.experimental) return false; + if (model.status && ["alpha", "beta", "deprecated"].includes(model.status)) return false; + return true; +} + +function supportsToolCall(model: z.infer): boolean { + return model.tool_call === true; +} + +function normalizeModels(models: Record>): ProviderSummary["models"] { + const list = Object.entries(models) + .map(([id, model]) => ({ + id: model.id ?? id, + name: model.name, + release_date: model.release_date, + tool_call: model.tool_call, + experimental: model.experimental, + status: model.status, + })) + .filter((model) => isStableModel(model) && supportsToolCall(model)) + .map(({ id, name, release_date }) => ({ id, name, release_date })); + + list.sort((a, b) => { + const aDate = a.release_date ? Date.parse(a.release_date) : 0; + const bDate = b.release_date ? Date.parse(b.release_date) : 0; + return bDate - aDate; + }); + return list; +} + +export async function listOnboardingModels(): Promise<{ providers: ProviderSummary[]; lastUpdated?: string }> { + const { data, fetchedAt } = await getModelsDevData(); + const providers: ProviderSummary[] = []; + const flavors: Array<"openai" | "anthropic" | "google"> = ["openai", "anthropic", "google"]; + + for (const flavor of flavors) { + const provider = pickProvider(data, flavor); + if (!provider) continue; + providers.push({ + id: flavor, + name: provider.name, + models: normalizeModels(provider.models), + }); + } + + return { providers, lastUpdated: fetchedAt }; +} diff --git a/apps/x/packages/core/src/models/models.ts b/apps/x/packages/core/src/models/models.ts index d2d846e5..482931df 100644 --- a/apps/x/packages/core/src/models/models.ts +++ b/apps/x/packages/core/src/models/models.ts @@ -1,119 +1,87 @@ import { ProviderV2 } from "@ai-sdk/provider"; -import { createGateway } from "ai"; +import { createGateway, generateText } from "ai"; import { createOpenAI } from "@ai-sdk/openai"; import { createGoogleGenerativeAI } from "@ai-sdk/google"; import { createAnthropic } from "@ai-sdk/anthropic"; import { createOllama } from "ollama-ai-provider-v2"; import { createOpenRouter } from '@openrouter/ai-sdk-provider'; import { createOpenAICompatible } from '@ai-sdk/openai-compatible'; -import { IModelConfigRepo } from "./repo.js"; -import container from "../di/container.js"; +import { LlmModelConfig, LlmProvider } from "@x/shared/dist/models.js"; import z from "zod"; -export const Flavor = z.enum([ - "rowboat [free]", - "aigateway", - "anthropic", - "google", - "ollama", - "openai", - "openai-compatible", - "openrouter", -]); +export const Provider = LlmProvider; +export const ModelConfig = LlmModelConfig; -export const Provider = z.object({ - flavor: Flavor, - apiKey: z.string().optional(), - baseURL: z.string().optional(), - headers: z.record(z.string(), z.string()).optional(), -}); - -export const ModelConfig = z.object({ - providers: z.record(z.string(), Provider), - defaults: z.object({ - provider: z.string(), - model: z.string(), - }), -}); - -const providerMap: Record = {}; - -export async function getProvider(name: string = ""): Promise { - // get model conf - const repo = container.resolve("modelConfigRepo"); - const modelConfig = await repo.getConfig(); - if (!modelConfig) { - throw new Error("Model config not found"); - } - if (!name) { - name = modelConfig.defaults.provider; - } - if (providerMap[name]) { - return providerMap[name]; - } - const providerConfig = modelConfig.providers[name]; - if (!providerConfig) { - throw new Error(`Provider ${name} not found`); - } - const { apiKey, baseURL, headers } = providerConfig; - switch (providerConfig.flavor) { - case "rowboat [free]": - providerMap[name] = createGateway({ - apiKey: "rowboatx", - baseURL: "https://ai-gateway.rowboatlabs.com/v1/ai", - }); - break; - case "openai": - providerMap[name] = createOpenAI({ +export function createProvider(config: z.infer): ProviderV2 { + const { apiKey, baseURL, headers } = config; + switch (config.flavor) { + case "openai": + return createOpenAI({ apiKey, baseURL, headers, }); - break; case "aigateway": - providerMap[name] = createGateway({ + return createGateway({ apiKey, baseURL, - headers - }); - break; - case "anthropic": - providerMap[name] = createAnthropic({ - apiKey, - baseURL, - headers - }); - break; - case "google": - providerMap[name] = createGoogleGenerativeAI({ - apiKey, - baseURL, - headers - }); - break; - case "ollama": - providerMap[name] = createOllama({ - baseURL, - headers - }); - break; - case "openai-compatible": - providerMap[name] = createOpenAICompatible({ - name, - apiKey, - baseURL : baseURL || "", headers, }); - break; - case "openrouter": - providerMap[name] = createOpenRouter({ + case "anthropic": + return createAnthropic({ apiKey, baseURL, - headers + headers, + }); + case "google": + return createGoogleGenerativeAI({ + apiKey, + baseURL, + headers, + }); + case "ollama": + return createOllama({ + baseURL, + headers, + }); + case "openai-compatible": + return createOpenAICompatible({ + name: "openai-compatible", + apiKey, + baseURL: baseURL || "", + headers, + }); + case "openrouter": + return createOpenRouter({ + apiKey, + baseURL, + headers, }); - break; default: - throw new Error(`Provider ${name} not found`); + throw new Error(`Unsupported provider flavor: ${config.flavor}`); } - return providerMap[name]; -} \ No newline at end of file +} + +export async function testModelConnection( + providerConfig: z.infer, + model: string, + timeoutMs: number = 8000, +): Promise<{ success: boolean; error?: string }> { + const controller = new AbortController(); + const timeout = setTimeout(() => controller.abort(), timeoutMs); + try { + const provider = createProvider(providerConfig); + const languageModel = provider.languageModel(model); + await generateText({ + model: languageModel, + prompt: "ping", + abortSignal: controller.signal, + }); + return { success: true }; + } catch (error) { + const message = error instanceof Error ? error.message : "Connection test failed"; + return { success: false, error: message }; + } finally { + clearTimeout(timeout); + } +} diff --git a/apps/x/packages/core/src/models/repo.ts b/apps/x/packages/core/src/models/repo.ts index 33ad2502..f39aadf1 100644 --- a/apps/x/packages/core/src/models/repo.ts +++ b/apps/x/packages/core/src/models/repo.ts @@ -1,4 +1,4 @@ -import { ModelConfig, Provider } from "./models.js"; +import { ModelConfig } from "./models.js"; import { WorkDir } from "../config/config.js"; import fs from "fs/promises"; import path from "path"; @@ -7,21 +7,14 @@ import z from "zod"; export interface IModelConfigRepo { ensureConfig(): Promise; getConfig(): Promise>; - upsert(providerName: string, config: z.infer): Promise; - delete(providerName: string): Promise; - setDefault(providerName: string, model: string): Promise; + setConfig(config: z.infer): Promise; } const defaultConfig: z.infer = { - providers: { - "rowboat": { - flavor: "rowboat [free]", - } + provider: { + flavor: "openai", }, - defaults: { - provider: "rowboat", - model: "gpt-5.1", - } + model: "gpt-4.1", }; export class FSModelConfigRepo implements IModelConfigRepo { @@ -40,28 +33,7 @@ export class FSModelConfigRepo implements IModelConfigRepo { return ModelConfig.parse(JSON.parse(config)); } - private async setConfig(config: z.infer): Promise { + async setConfig(config: z.infer): Promise { await fs.writeFile(this.configPath, JSON.stringify(config, null, 2)); } - - async upsert(providerName: string, config: z.infer): Promise { - const conf = await this.getConfig(); - conf.providers[providerName] = config; - await this.setConfig(conf); - } - - async delete(providerName: string): Promise { - const conf = await this.getConfig(); - delete conf.providers[providerName]; - await this.setConfig(conf); - } - - async setDefault(providerName: string, model: string): Promise { - const conf = await this.getConfig(); - conf.defaults = { - provider: providerName, - model, - }; - await this.setConfig(conf); - } -} \ No newline at end of file +} diff --git a/apps/x/packages/shared/src/index.ts b/apps/x/packages/shared/src/index.ts index 878c1043..3bca8969 100644 --- a/apps/x/packages/shared/src/index.ts +++ b/apps/x/packages/shared/src/index.ts @@ -1,6 +1,7 @@ import { PrefixLogger } from './prefix-logger.js'; export * as ipc from './ipc.js'; +export * as models from './models.js'; export * as workspace from './workspace.js'; export * as mcp from './mcp.js'; -export { PrefixLogger }; \ No newline at end of file +export { PrefixLogger }; diff --git a/apps/x/packages/shared/src/ipc.ts b/apps/x/packages/shared/src/ipc.ts index aca2ce17..2835a90b 100644 --- a/apps/x/packages/shared/src/ipc.ts +++ b/apps/x/packages/shared/src/ipc.ts @@ -2,6 +2,7 @@ import { z } from 'zod'; import { RelPath, Encoding, Stat, DirEntry, ReaddirOptions, ReadFileResult, WorkspaceChangeEvent, WriteFileOptions, WriteFileResult, RemoveOptions } from './workspace.js'; import { ListToolsResponse } from './mcp.js'; import { AskHumanResponsePayload, CreateRunOptions, Run, ListRunsResponse, ToolPermissionAuthorizePayload } from './runs.js'; +import { LlmModelConfig } from './models.js'; // ============================================================================ // Runtime Validation Schemas (Single Source of Truth) @@ -173,6 +174,34 @@ const ipcSchemas = { req: z.null(), res: z.null(), }, + 'models:list': { + req: z.null(), + res: z.object({ + providers: z.array(z.object({ + id: z.string(), + name: z.string(), + models: z.array(z.object({ + id: z.string(), + name: z.string().optional(), + release_date: z.string().optional(), + })), + })), + lastUpdated: z.string().optional(), + }), + }, + 'models:test': { + req: LlmModelConfig, + res: z.object({ + success: z.boolean(), + error: z.string().optional(), + }), + }, + 'models:saveConfig': { + req: LlmModelConfig, + res: z.object({ + success: z.literal(true), + }), + }, 'oauth:connect': { req: z.object({ provider: z.string(), @@ -373,4 +402,4 @@ export function validateResponse( ): IPCChannels[K]['res'] { const schema = ipcSchemas[channel].res; return schema.parse(data) as IPCChannels[K]['res']; -} \ No newline at end of file +} diff --git a/apps/x/packages/shared/src/models.ts b/apps/x/packages/shared/src/models.ts new file mode 100644 index 00000000..14e91689 --- /dev/null +++ b/apps/x/packages/shared/src/models.ts @@ -0,0 +1,13 @@ +import { z } from "zod"; + +export const LlmProvider = z.object({ + flavor: z.enum(["openai", "anthropic", "google", "openrouter", "aigateway", "ollama", "openai-compatible"]), + apiKey: z.string().optional(), + baseURL: z.string().optional(), + headers: z.record(z.string(), z.string()).optional(), +}); + +export const LlmModelConfig = z.object({ + provider: LlmProvider, + model: z.string(), +});