build model selection

This commit is contained in:
Ramnique Singh 2025-11-20 16:41:41 +05:30
parent add897e448
commit 92004033de
8 changed files with 357 additions and 66 deletions

View file

@@ -1,13 +1,15 @@
import path from "path";
import fs from "fs";
import { McpServerConfig } from "../entities/mcp.js";
import { ModelConfig as ModelConfigT } from "../entities/models.js";
import { ModelConfig } from "../entities/models.js";
import { z } from "zod";
import { homedir } from "os";
// Resolve app root relative to compiled file location (dist/...)
export const WorkDir = path.join(homedir(), ".rowboat");
let modelConfig: z.infer<typeof ModelConfig> | null = null;
const baseMcpConfig: z.infer<typeof McpServerConfig> = {
mcpServers: {
firecrawl: {
@@ -26,27 +28,6 @@ const baseMcpConfig: z.infer<typeof McpServerConfig> = {
}
};
// Built-in fallback model configuration, written to
// ~/.rowboat/config/models.json on first run (see ensureModelConfig).
// Registers the four bundled provider flavors with no credentials and
// routes requests to OpenAI's gpt-5.1 by default.
const baseModelConfig: z.infer<typeof ModelConfigT> = {
providers: {
openai: {
flavor: "openai",
},
anthropic: {
flavor: "anthropic",
},
google: {
flavor: "google",
},
ollama: {
flavor: "ollama",
}
},
// Used when an agent does not name a provider/model explicitly.
defaults: {
provider: "openai",
model: "gpt-5.1",
}
};
function ensureMcpConfig() {
const configPath = path.join(WorkDir, "config", "mcp.json");
if (!fs.existsSync(configPath)) {
@@ -54,24 +35,14 @@ function ensureMcpConfig() {
}
}
// Seed ~/.rowboat/config/models.json with the built-in defaults when the
// user has no model configuration yet; an existing file is left untouched.
function ensureModelConfig() {
    const target = path.join(WorkDir, "config", "models.json");
    if (fs.existsSync(target)) return;
    fs.writeFileSync(target, JSON.stringify(baseModelConfig, null, 2));
}
// Create the ~/.rowboat working tree (root, agents/, config/) if any part
// is missing, then make sure both config files exist with defaults.
function ensureDirs() {
    const mkdirIfMissing = (dir: string) => {
        if (!fs.existsSync(dir)) {
            fs.mkdirSync(dir, { recursive: true });
        }
    };
    for (const dir of [WorkDir, path.join(WorkDir, "agents"), path.join(WorkDir, "config")]) {
        mkdirIfMissing(dir);
    }
    ensureMcpConfig();
    ensureModelConfig();
}
ensureDirs();
function loadMcpServerConfig(): z.infer<typeof McpServerConfig> {
const configPath = path.join(WorkDir, "config", "mcp.json");
if (!fs.existsSync(configPath)) return { mcpServers: {} };
@@ -79,13 +50,27 @@ function loadMcpServerConfig(): z.infer<typeof McpServerConfig> {
return McpServerConfig.parse(JSON.parse(config));
}
function loadModelConfig(): z.infer<typeof ModelConfigT> {
export async function getModelConfig(): Promise<z.infer<typeof ModelConfig> | null> {
if (modelConfig) {
return modelConfig;
}
const configPath = path.join(WorkDir, "config", "models.json");
if (!fs.existsSync(configPath)) return baseModelConfig;
const config = fs.readFileSync(configPath, "utf8");
return ModelConfigT.parse(JSON.parse(config));
try {
const config = await fs.promises.readFile(configPath, "utf8");
modelConfig = ModelConfig.parse(JSON.parse(config));
return modelConfig;
} catch (error) {
console.error(`Warning! model config not found!`);
return null;
}
}
// Replace the in-memory model configuration and persist it to
// ~/.rowboat/config/models.json as pretty-printed JSON.
export async function updateModelConfig(config: z.infer<typeof ModelConfig>) {
    modelConfig = config;
    const target = path.join(WorkDir, "config", "models.json");
    const serialized = JSON.stringify(config, null, 2);
    await fs.promises.writeFile(target, serialized);
}
// One-time module initialization: make sure the working tree and config
// files exist, then expose the configured MCP servers. (Model config is
// now loaded lazily via getModelConfig, not eagerly at import time.)
ensureDirs();
const { mcpServers } = loadMcpServerConfig();
export const McpServers = mcpServers;

View file

@@ -1,7 +1,14 @@
import z from "zod";
export const Provider = z.object({
flavor: z.enum(["openai", "anthropic", "google", "ollama"]),
flavor: z.enum([
"anthropic",
"google",
"ollama",
"openai",
"openai-compatible",
"openrouter",
]),
apiKey: z.string().optional(),
baseURL: z.string().optional(),
headers: z.record(z.string(), z.string()).optional(),

View file

@@ -1,7 +1,7 @@
import { jsonSchema, ModelMessage } from "ai";
import fs from "fs";
import path from "path";
import { ModelConfig, WorkDir } from "../config/config.js";
import { getModelConfig, WorkDir } from "../config/config.js";
import { Agent, ToolAttachment } from "../entities/agent.js";
import { AssistantContentPart, AssistantMessage, Message, MessageList, ToolCallPart, ToolMessage, UserMessage } from "../entities/message.js";
import { runIdGenerator } from "./run-id-gen.js";
@@ -405,6 +405,12 @@ export class AgentState {
}
export async function* streamAgent(state: AgentState): AsyncGenerator<z.infer<typeof RunEvent>, void, unknown> {
// get model config
const modelConfig = await getModelConfig();
if (!modelConfig) {
throw new Error("Model config not found");
}
// set up agent
const agent = await loadAgent(state.agentName);
@@ -412,8 +418,8 @@ export async function* streamAgent(state: AgentState): AsyncGenerator<z.infer<ty
const tools = await buildTools(agent);
// set up provider + model
const provider = getProvider(agent.provider);
const model = provider(agent.model || ModelConfig.defaults.model);
const provider = await getProvider(agent.provider);
const model = provider.languageModel(agent.model || modelConfig.defaults.model);
let loopCounter = 0;
while (true) {

View file

@@ -1,50 +1,76 @@
import { createOpenAI, OpenAIProvider } from "@ai-sdk/openai";
import { createGoogleGenerativeAI, GoogleGenerativeAIProvider } from "@ai-sdk/google";
import { AnthropicProvider, createAnthropic } from "@ai-sdk/anthropic";
import { OllamaProvider, createOllama } from "ollama-ai-provider-v2";
import { ModelConfig } from "../config/config.js";
import { ProviderV2 } from "@ai-sdk/provider";
import { createOpenAI } from "@ai-sdk/openai";
import { createGoogleGenerativeAI } from "@ai-sdk/google";
import { createAnthropic } from "@ai-sdk/anthropic";
import { createOllama } from "ollama-ai-provider-v2";
import { createOpenRouter } from '@openrouter/ai-sdk-provider';
import { createOpenAICompatible } from '@ai-sdk/openai-compatible';
import { getModelConfig } from "../config/config.js";
const providerMap: Record<string, OpenAIProvider | GoogleGenerativeAIProvider | AnthropicProvider | OllamaProvider> = {};
const providerMap: Record<string, ProviderV2> = {};
export function getProvider(name: string = "") {
export async function getProvider(name: string = ""): Promise<ProviderV2> {
// get model conf
const modelConfig = await getModelConfig();
if (!modelConfig) {
throw new Error("Model config not found");
}
if (!name) {
name = ModelConfig.defaults.provider;
name = modelConfig.defaults.provider;
}
if (providerMap[name]) {
return providerMap[name];
}
const providerConfig = ModelConfig.providers[name];
const providerConfig = modelConfig.providers[name];
if (!providerConfig) {
throw new Error(`Provider ${name} not found`);
}
const { apiKey, baseURL, headers } = providerConfig;
switch (providerConfig.flavor) {
case "openai":
providerMap[name] = createOpenAI({
apiKey: providerConfig.apiKey,
baseURL: providerConfig.baseURL,
headers: providerConfig.headers,
apiKey,
baseURL,
headers,
});
break;
case "anthropic":
providerMap[name] = createAnthropic({
apiKey: providerConfig.apiKey,
baseURL: providerConfig.baseURL,
headers: providerConfig.headers,
apiKey,
baseURL,
headers
});
break;
case "google":
providerMap[name] = createGoogleGenerativeAI({
apiKey: providerConfig.apiKey,
baseURL: providerConfig.baseURL,
headers: providerConfig.headers,
apiKey,
baseURL,
headers
});
break;
case "ollama":
providerMap[name] = createOllama({
baseURL: providerConfig.baseURL,
headers: providerConfig.headers,
baseURL,
headers
});
break;
case "openai-compatible":
providerMap[name] = createOpenAICompatible({
name,
apiKey,
baseURL : baseURL || "",
headers
});
break;
case "openrouter":
providerMap[name] = createOpenRouter({
apiKey,
baseURL,
headers
});
break;
default:
throw new Error(`Provider ${name} not found`);
}
return providerMap[name];
}