feat: simplify LLM config and onboarding

This commit is contained in:
Ramnique Singh 2026-02-04 01:12:06 +05:30
parent 948c6e7176
commit 10f94ce67e
10 changed files with 630 additions and 153 deletions

View file

@ -15,7 +15,7 @@ import { CopilotAgent } from "../application/assistant/agent.js";
import { isBlocked } from "../application/lib/command-executor.js";
import container from "../di/container.js";
import { IModelConfigRepo } from "../models/repo.js";
import { getProvider } from "../models/models.js";
import { createProvider } from "../models/models.js";
import { IAgentsRepo } from "./repo.js";
import { IMonotonicallyIncreasingIdGenerator } from "../application/lib/id-gen.js";
import { IBus } from "../application/lib/bus.js";
@ -623,8 +623,8 @@ export async function* streamAgent({
const tools = await buildTools(agent);
// set up provider + model
const provider = await getProvider(agent.provider);
const model = provider.languageModel(agent.model || modelConfig.defaults.model);
const provider = createProvider(modelConfig.provider);
const model = provider.languageModel(modelConfig.model);
let loopCounter = 0;
while (true) {

View file

@ -0,0 +1,174 @@
import fs from "node:fs/promises";
import path from "node:path";
import z from "zod";
import { WorkDir } from "../config/config.js";
const CACHE_PATH = path.join(WorkDir, "config", "models.dev.json");
const CACHE_TTL_MS = 24 * 60 * 60 * 1000;
const ModelsDevModel = z.object({
id: z.string().optional(),
name: z.string().optional(),
release_date: z.string().optional(),
tool_call: z.boolean().optional(),
experimental: z.boolean().optional(),
status: z.enum(["alpha", "beta", "deprecated"]).optional(),
}).passthrough();
const ModelsDevProvider = z.object({
id: z.string().optional(),
name: z.string(),
models: z.record(z.string(), ModelsDevModel),
}).passthrough();
const ModelsDevResponse = z.record(z.string(), ModelsDevProvider);
type ProviderSummary = {
id: string;
name: string;
models: Array<{
id: string;
name?: string;
release_date?: string;
}>;
};
type CacheFile = {
fetchedAt: string;
data: unknown;
};
async function readCache(): Promise<CacheFile | null> {
try {
const raw = await fs.readFile(CACHE_PATH, "utf8");
return JSON.parse(raw) as CacheFile;
} catch {
return null;
}
}
async function writeCache(data: unknown): Promise<void> {
const payload: CacheFile = {
fetchedAt: new Date().toISOString(),
data,
};
await fs.writeFile(CACHE_PATH, JSON.stringify(payload, null, 2));
}
async function fetchModelsDev(): Promise<unknown> {
const response = await fetch("https://models.dev/api.json", {
headers: { "User-Agent": "Rowboat" },
});
if (!response.ok) {
throw new Error(`models.dev fetch failed: ${response.status}`);
}
return response.json();
}
function isCacheFresh(fetchedAt: string): boolean {
const age = Date.now() - new Date(fetchedAt).getTime();
return age < CACHE_TTL_MS;
}
async function getModelsDevData(): Promise<{ data: z.infer<typeof ModelsDevResponse>; fetchedAt?: string }> {
const cached = await readCache();
if (cached?.fetchedAt && isCacheFresh(cached.fetchedAt)) {
const parsed = ModelsDevResponse.safeParse(cached.data);
if (parsed.success) {
return { data: parsed.data, fetchedAt: cached.fetchedAt };
}
}
try {
const fresh = await fetchModelsDev();
const parsed = ModelsDevResponse.parse(fresh);
await writeCache(parsed);
return { data: parsed, fetchedAt: new Date().toISOString() };
} catch (error) {
if (cached) {
const parsed = ModelsDevResponse.safeParse(cached.data);
if (parsed.success) {
return { data: parsed.data, fetchedAt: cached.fetchedAt };
}
}
throw error;
}
}
function scoreProvider(flavor: string, id: string, name: string): number {
const normalizedId = id.toLowerCase();
const normalizedName = name.toLowerCase();
let score = 0;
if (normalizedId === flavor) score += 100;
if (normalizedName.includes(flavor)) score += 20;
if (flavor === "google") {
if (normalizedName.includes("gemini")) score += 10;
if (normalizedName.includes("vertex")) score -= 5;
}
return score;
}
function pickProvider(
data: z.infer<typeof ModelsDevResponse>,
flavor: "openai" | "anthropic" | "google",
): z.infer<typeof ModelsDevProvider> | null {
if (data[flavor]) return data[flavor];
let best: { score: number; provider: z.infer<typeof ModelsDevProvider> } | null = null;
for (const [id, provider] of Object.entries(data)) {
const s = scoreProvider(flavor, id, provider.name);
if (s <= 0) continue;
if (!best || s > best.score) {
best = { score: s, provider };
}
}
return best?.provider ?? null;
}
function isStableModel(model: z.infer<typeof ModelsDevModel>): boolean {
if (model.experimental) return false;
if (model.status && ["alpha", "beta", "deprecated"].includes(model.status)) return false;
return true;
}
function supportsToolCall(model: z.infer<typeof ModelsDevModel>): boolean {
return model.tool_call === true;
}
function normalizeModels(models: Record<string, z.infer<typeof ModelsDevModel>>): ProviderSummary["models"] {
const list = Object.entries(models)
.map(([id, model]) => ({
id: model.id ?? id,
name: model.name,
release_date: model.release_date,
tool_call: model.tool_call,
experimental: model.experimental,
status: model.status,
}))
.filter((model) => isStableModel(model) && supportsToolCall(model))
.map(({ id, name, release_date }) => ({ id, name, release_date }));
list.sort((a, b) => {
const aDate = a.release_date ? Date.parse(a.release_date) : 0;
const bDate = b.release_date ? Date.parse(b.release_date) : 0;
return bDate - aDate;
});
return list;
}
export async function listOnboardingModels(): Promise<{ providers: ProviderSummary[]; lastUpdated?: string }> {
const { data, fetchedAt } = await getModelsDevData();
const providers: ProviderSummary[] = [];
const flavors: Array<"openai" | "anthropic" | "google"> = ["openai", "anthropic", "google"];
for (const flavor of flavors) {
const provider = pickProvider(data, flavor);
if (!provider) continue;
providers.push({
id: flavor,
name: provider.name,
models: normalizeModels(provider.models),
});
}
return { providers, lastUpdated: fetchedAt };
}

View file

@ -1,119 +1,87 @@
import { ProviderV2 } from "@ai-sdk/provider";
import { createGateway } from "ai";
import { createGateway, generateText } from "ai";
import { createOpenAI } from "@ai-sdk/openai";
import { createGoogleGenerativeAI } from "@ai-sdk/google";
import { createAnthropic } from "@ai-sdk/anthropic";
import { createOllama } from "ollama-ai-provider-v2";
import { createOpenRouter } from '@openrouter/ai-sdk-provider';
import { createOpenAICompatible } from '@ai-sdk/openai-compatible';
import { IModelConfigRepo } from "./repo.js";
import container from "../di/container.js";
import { LlmModelConfig, LlmProvider } from "@x/shared/dist/models.js";
import z from "zod";
export const Flavor = z.enum([
"rowboat [free]",
"aigateway",
"anthropic",
"google",
"ollama",
"openai",
"openai-compatible",
"openrouter",
]);
export const Provider = LlmProvider;
export const ModelConfig = LlmModelConfig;
export const Provider = z.object({
flavor: Flavor,
apiKey: z.string().optional(),
baseURL: z.string().optional(),
headers: z.record(z.string(), z.string()).optional(),
});
export const ModelConfig = z.object({
providers: z.record(z.string(), Provider),
defaults: z.object({
provider: z.string(),
model: z.string(),
}),
});
const providerMap: Record<string, ProviderV2> = {};
export async function getProvider(name: string = ""): Promise<ProviderV2> {
// get model conf
const repo = container.resolve<IModelConfigRepo>("modelConfigRepo");
const modelConfig = await repo.getConfig();
if (!modelConfig) {
throw new Error("Model config not found");
}
if (!name) {
name = modelConfig.defaults.provider;
}
if (providerMap[name]) {
return providerMap[name];
}
const providerConfig = modelConfig.providers[name];
if (!providerConfig) {
throw new Error(`Provider ${name} not found`);
}
const { apiKey, baseURL, headers } = providerConfig;
switch (providerConfig.flavor) {
case "rowboat [free]":
providerMap[name] = createGateway({
apiKey: "rowboatx",
baseURL: "https://ai-gateway.rowboatlabs.com/v1/ai",
});
break;
case "openai":
providerMap[name] = createOpenAI({
export function createProvider(config: z.infer<typeof Provider>): ProviderV2 {
const { apiKey, baseURL, headers } = config;
switch (config.flavor) {
case "openai":
return createOpenAI({
apiKey,
baseURL,
headers,
});
break;
case "aigateway":
providerMap[name] = createGateway({
return createGateway({
apiKey,
baseURL,
headers
});
break;
case "anthropic":
providerMap[name] = createAnthropic({
apiKey,
baseURL,
headers
});
break;
case "google":
providerMap[name] = createGoogleGenerativeAI({
apiKey,
baseURL,
headers
});
break;
case "ollama":
providerMap[name] = createOllama({
baseURL,
headers
});
break;
case "openai-compatible":
providerMap[name] = createOpenAICompatible({
name,
apiKey,
baseURL : baseURL || "",
headers,
});
break;
case "openrouter":
providerMap[name] = createOpenRouter({
case "anthropic":
return createAnthropic({
apiKey,
baseURL,
headers
headers,
});
case "google":
return createGoogleGenerativeAI({
apiKey,
baseURL,
headers,
});
case "ollama":
return createOllama({
baseURL,
headers,
});
case "openai-compatible":
return createOpenAICompatible({
name: "openai-compatible",
apiKey,
baseURL: baseURL || "",
headers,
});
case "openrouter":
return createOpenRouter({
apiKey,
baseURL,
headers,
});
break;
default:
throw new Error(`Provider ${name} not found`);
throw new Error(`Unsupported provider flavor: ${config.flavor}`);
}
return providerMap[name];
}
}
export async function testModelConnection(
providerConfig: z.infer<typeof Provider>,
model: string,
timeoutMs: number = 8000,
): Promise<{ success: boolean; error?: string }> {
const controller = new AbortController();
const timeout = setTimeout(() => controller.abort(), timeoutMs);
try {
const provider = createProvider(providerConfig);
const languageModel = provider.languageModel(model);
await generateText({
model: languageModel,
prompt: "ping",
abortSignal: controller.signal,
});
return { success: true };
} catch (error) {
const message = error instanceof Error ? error.message : "Connection test failed";
return { success: false, error: message };
} finally {
clearTimeout(timeout);
}
}

View file

@ -1,4 +1,4 @@
import { ModelConfig, Provider } from "./models.js";
import { ModelConfig } from "./models.js";
import { WorkDir } from "../config/config.js";
import fs from "fs/promises";
import path from "path";
@ -7,21 +7,14 @@ import z from "zod";
export interface IModelConfigRepo {
ensureConfig(): Promise<void>;
getConfig(): Promise<z.infer<typeof ModelConfig>>;
upsert(providerName: string, config: z.infer<typeof Provider>): Promise<void>;
delete(providerName: string): Promise<void>;
setDefault(providerName: string, model: string): Promise<void>;
setConfig(config: z.infer<typeof ModelConfig>): Promise<void>;
}
const defaultConfig: z.infer<typeof ModelConfig> = {
providers: {
"rowboat": {
flavor: "rowboat [free]",
}
provider: {
flavor: "openai",
},
defaults: {
provider: "rowboat",
model: "gpt-5.1",
}
model: "gpt-4.1",
};
export class FSModelConfigRepo implements IModelConfigRepo {
@ -40,28 +33,7 @@ export class FSModelConfigRepo implements IModelConfigRepo {
return ModelConfig.parse(JSON.parse(config));
}
private async setConfig(config: z.infer<typeof ModelConfig>): Promise<void> {
async setConfig(config: z.infer<typeof ModelConfig>): Promise<void> {
await fs.writeFile(this.configPath, JSON.stringify(config, null, 2));
}
async upsert(providerName: string, config: z.infer<typeof Provider>): Promise<void> {
const conf = await this.getConfig();
conf.providers[providerName] = config;
await this.setConfig(conf);
}
async delete(providerName: string): Promise<void> {
const conf = await this.getConfig();
delete conf.providers[providerName];
await this.setConfig(conf);
}
async setDefault(providerName: string, model: string): Promise<void> {
const conf = await this.getConfig();
conf.defaults = {
provider: providerName,
model,
};
await this.setConfig(conf);
}
}
}