mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-04-25 08:26:22 +02:00
build model selection
This commit is contained in:
parent
add897e448
commit
92004033de
8 changed files with 357 additions and 66 deletions
|
|
@ -3,7 +3,7 @@ import { StreamRenderer } from "./application/lib/stream-renderer.js";
|
|||
import { stdin as input, stdout as output } from "node:process";
|
||||
import fs from "fs";
|
||||
import path from "path";
|
||||
import { WorkDir } from "./application/config/config.js";
|
||||
import { WorkDir, getModelConfig, updateModelConfig } from "./application/config/config.js";
|
||||
import { RunEvent } from "./application/entities/run-events.js";
|
||||
import { createInterface, Interface } from "node:readline/promises";
|
||||
import { ToolCallPart } from "./application/entities/message.js";
|
||||
|
|
@ -54,6 +54,12 @@ export async function app(opts: {
|
|||
input?: string;
|
||||
noInteractive?: boolean;
|
||||
}) {
|
||||
// check if model config is required
|
||||
const c = await getModelConfig();
|
||||
if (!c) {
|
||||
await modelConfig();
|
||||
}
|
||||
|
||||
const renderer = new StreamRenderer();
|
||||
const state = new AgentState(opts.agent, opts.runId);
|
||||
|
||||
|
|
@ -202,4 +208,177 @@ async function getUserInput(
|
|||
process.exit(0);
|
||||
}
|
||||
return input;
|
||||
}
|
||||
|
||||
export async function modelConfig() {
|
||||
// load existing model config
|
||||
const config = await getModelConfig();
|
||||
|
||||
const rl = createInterface({ input, output });
|
||||
try {
|
||||
const flavors = [
|
||||
"openai",
|
||||
"anthropic",
|
||||
"google",
|
||||
"ollama",
|
||||
"openai-compatible",
|
||||
"openrouter",
|
||||
] as const;
|
||||
const defaultBaseUrls: Record<(typeof flavors)[number], string> = {
|
||||
openai: "https://api.openai.com/v1",
|
||||
anthropic: "https://api.anthropic.com/v1",
|
||||
google: "https://generativelanguage.googleapis.com",
|
||||
ollama: "http://localhost:11434",
|
||||
"openai-compatible": "http://localhost:8080/v1",
|
||||
openrouter: "https://openrouter.ai/api/v1",
|
||||
};
|
||||
const defaultModels: Record<(typeof flavors)[number], string> = {
|
||||
openai: "gpt-5.1",
|
||||
anthropic: "claude-3.5-sonnet",
|
||||
google: "gemini-1.5-pro",
|
||||
ollama: "llama3.1",
|
||||
"openai-compatible": "gpt-4o",
|
||||
openrouter: "openrouter/auto",
|
||||
};
|
||||
|
||||
const currentProvider = config?.defaults?.provider;
|
||||
const currentModel = config?.defaults?.model;
|
||||
const currentProviderConfig = currentProvider ? config?.providers?.[currentProvider] : undefined;
|
||||
if (config) {
|
||||
console.log("Currently using:");
|
||||
console.log(`- provider: ${currentProvider || "none"}${currentProviderConfig?.flavor ? ` (${currentProviderConfig.flavor})` : ""}`);
|
||||
console.log(`- model: ${currentModel || "none"}`);
|
||||
console.log("");
|
||||
}
|
||||
|
||||
const flavorPromptLines = flavors
|
||||
.map((f, idx) => ` ${idx + 1}. ${f}`)
|
||||
.join("\n");
|
||||
const flavorAnswer = await rl.question(
|
||||
`Select a provider type:\n${flavorPromptLines}\nEnter number or name` +
|
||||
(currentProvider ? ` [${currentProvider}]` : "") +
|
||||
": ",
|
||||
);
|
||||
let selectedFlavorRaw = flavorAnswer.trim();
|
||||
let selectedFlavor: (typeof flavors)[number] | null = null;
|
||||
if (selectedFlavorRaw === "" && currentProvider && (flavors as readonly string[]).includes(currentProvider)) {
|
||||
selectedFlavor = currentProvider as (typeof flavors)[number];
|
||||
} else if (/^\d+$/.test(selectedFlavorRaw)) {
|
||||
const idx = parseInt(selectedFlavorRaw, 10) - 1;
|
||||
if (idx >= 0 && idx < flavors.length) {
|
||||
selectedFlavor = flavors[idx];
|
||||
}
|
||||
} else if ((flavors as readonly string[]).includes(selectedFlavorRaw)) {
|
||||
selectedFlavor = selectedFlavorRaw as (typeof flavors)[number];
|
||||
}
|
||||
if (!selectedFlavor) {
|
||||
console.error("Invalid selection. Exiting.");
|
||||
return;
|
||||
}
|
||||
|
||||
const existingAliases = Object.keys(config?.providers || {}).filter(
|
||||
(name) => config?.providers?.[name]?.flavor === selectedFlavor,
|
||||
);
|
||||
let providerName: string | null = null;
|
||||
let chooseMode: "existing" | "add" = "add";
|
||||
if (existingAliases.length > 0) {
|
||||
const listLines = existingAliases
|
||||
.map((alias, idx) => ` ${idx + 1}. use existing: ${alias}`)
|
||||
.join("\n");
|
||||
const addIndex = existingAliases.length + 1;
|
||||
const providerSelect = await rl.question(
|
||||
`Found existing providers for ${selectedFlavor}:\n${listLines}\n ${addIndex}. add new\nEnter number or name/alias [${addIndex}]: `,
|
||||
);
|
||||
const sel = providerSelect.trim();
|
||||
if (sel === "" || sel.toLowerCase() === "add" || sel.toLowerCase() === "new") {
|
||||
chooseMode = "add";
|
||||
} else if (/^\d+$/.test(sel)) {
|
||||
const idx = parseInt(sel, 10) - 1;
|
||||
if (idx >= 0 && idx < existingAliases.length) {
|
||||
providerName = existingAliases[idx];
|
||||
chooseMode = "existing";
|
||||
} else if (idx === existingAliases.length) {
|
||||
chooseMode = "add";
|
||||
} else {
|
||||
console.error("Invalid selection. Exiting.");
|
||||
return;
|
||||
}
|
||||
} else if (existingAliases.includes(sel)) {
|
||||
providerName = sel;
|
||||
chooseMode = "existing";
|
||||
} else {
|
||||
console.error("Invalid selection. Exiting.");
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (chooseMode === "existing" && !providerName) {
|
||||
console.error("No provider selected. Exiting.");
|
||||
return;
|
||||
}
|
||||
|
||||
if (chooseMode === "existing") {
|
||||
const modelDefault =
|
||||
currentProvider === providerName && currentModel
|
||||
? currentModel
|
||||
: defaultModels[selectedFlavor];
|
||||
const modelAns = await rl.question(
|
||||
`Specify model for ${selectedFlavor} [${modelDefault}]: `,
|
||||
);
|
||||
const model = modelAns.trim() || modelDefault;
|
||||
|
||||
const newConfig = {
|
||||
providers: { ...(config?.providers || {}) },
|
||||
defaults: {
|
||||
provider: providerName!,
|
||||
model,
|
||||
},
|
||||
};
|
||||
await updateModelConfig(newConfig as any);
|
||||
console.log(`Model configuration updated. Provider set to '${providerName}'.`);
|
||||
return;
|
||||
}
|
||||
|
||||
const providerNameAns = await rl.question(
|
||||
`Enter a name/alias for this provider [${selectedFlavor}]: `,
|
||||
);
|
||||
providerName = providerNameAns.trim() || selectedFlavor;
|
||||
|
||||
const baseUrlDefault = defaultBaseUrls[selectedFlavor] || "";
|
||||
const baseUrlAns = await rl.question(
|
||||
`Enter baseURL for ${selectedFlavor} [${baseUrlDefault}]: `,
|
||||
);
|
||||
const baseURL = (baseUrlAns.trim() || baseUrlDefault) || undefined;
|
||||
|
||||
const apiKeyAns = await rl.question(
|
||||
`Enter API key for ${selectedFlavor} (leave blank to skip): `,
|
||||
);
|
||||
const apiKey = apiKeyAns.trim() || undefined;
|
||||
|
||||
const modelDefault = defaultModels[selectedFlavor];
|
||||
const modelAns = await rl.question(
|
||||
`Specify model for ${selectedFlavor} [${modelDefault}]: `,
|
||||
);
|
||||
const model = modelAns.trim() || modelDefault;
|
||||
|
||||
const mergedProviders = {
|
||||
...(config?.providers || {}),
|
||||
[providerName]: {
|
||||
flavor: selectedFlavor,
|
||||
...(apiKey ? { apiKey } : {}),
|
||||
...(baseURL ? { baseURL } : {}),
|
||||
},
|
||||
};
|
||||
const newConfig = {
|
||||
providers: mergedProviders,
|
||||
defaults: {
|
||||
provider: providerName,
|
||||
model,
|
||||
},
|
||||
};
|
||||
|
||||
await updateModelConfig(newConfig as any);
|
||||
console.log(`Model configuration updated. Provider '${providerName}' ${config?.providers?.[providerName] ? "overwritten" : "added"}.`);
|
||||
} finally {
|
||||
rl.close();
|
||||
}
|
||||
}
|
||||
|
|
@ -1,13 +1,15 @@
|
|||
import path from "path";
|
||||
import fs from "fs";
|
||||
import { McpServerConfig } from "../entities/mcp.js";
|
||||
import { ModelConfig as ModelConfigT } from "../entities/models.js";
|
||||
import { ModelConfig } from "../entities/models.js";
|
||||
import { z } from "zod";
|
||||
import { homedir } from "os";
|
||||
|
||||
// Resolve app root relative to compiled file location (dist/...)
|
||||
export const WorkDir = path.join(homedir(), ".rowboat");
|
||||
|
||||
let modelConfig: z.infer<typeof ModelConfig> | null = null;
|
||||
|
||||
const baseMcpConfig: z.infer<typeof McpServerConfig> = {
|
||||
mcpServers: {
|
||||
firecrawl: {
|
||||
|
|
@ -26,27 +28,6 @@ const baseMcpConfig: z.infer<typeof McpServerConfig> = {
|
|||
}
|
||||
};
|
||||
|
||||
const baseModelConfig: z.infer<typeof ModelConfigT> = {
|
||||
providers: {
|
||||
openai: {
|
||||
flavor: "openai",
|
||||
},
|
||||
anthropic: {
|
||||
flavor: "anthropic",
|
||||
},
|
||||
google: {
|
||||
flavor: "google",
|
||||
},
|
||||
ollama: {
|
||||
flavor: "ollama",
|
||||
}
|
||||
},
|
||||
defaults: {
|
||||
provider: "openai",
|
||||
model: "gpt-5.1",
|
||||
}
|
||||
};
|
||||
|
||||
function ensureMcpConfig() {
|
||||
const configPath = path.join(WorkDir, "config", "mcp.json");
|
||||
if (!fs.existsSync(configPath)) {
|
||||
|
|
@ -54,24 +35,14 @@ function ensureMcpConfig() {
|
|||
}
|
||||
}
|
||||
|
||||
function ensureModelConfig() {
|
||||
const configPath = path.join(WorkDir, "config", "models.json");
|
||||
if (!fs.existsSync(configPath)) {
|
||||
fs.writeFileSync(configPath, JSON.stringify(baseModelConfig, null, 2));
|
||||
}
|
||||
}
|
||||
|
||||
function ensureDirs() {
|
||||
const ensure = (p: string) => { if (!fs.existsSync(p)) fs.mkdirSync(p, { recursive: true }); };
|
||||
ensure(WorkDir);
|
||||
ensure(path.join(WorkDir, "agents"));
|
||||
ensure(path.join(WorkDir, "config"));
|
||||
ensureMcpConfig();
|
||||
ensureModelConfig();
|
||||
}
|
||||
|
||||
ensureDirs();
|
||||
|
||||
function loadMcpServerConfig(): z.infer<typeof McpServerConfig> {
|
||||
const configPath = path.join(WorkDir, "config", "mcp.json");
|
||||
if (!fs.existsSync(configPath)) return { mcpServers: {} };
|
||||
|
|
@ -79,13 +50,27 @@ function loadMcpServerConfig(): z.infer<typeof McpServerConfig> {
|
|||
return McpServerConfig.parse(JSON.parse(config));
|
||||
}
|
||||
|
||||
function loadModelConfig(): z.infer<typeof ModelConfigT> {
|
||||
export async function getModelConfig(): Promise<z.infer<typeof ModelConfig> | null> {
|
||||
if (modelConfig) {
|
||||
return modelConfig;
|
||||
}
|
||||
const configPath = path.join(WorkDir, "config", "models.json");
|
||||
if (!fs.existsSync(configPath)) return baseModelConfig;
|
||||
const config = fs.readFileSync(configPath, "utf8");
|
||||
return ModelConfigT.parse(JSON.parse(config));
|
||||
try {
|
||||
const config = await fs.promises.readFile(configPath, "utf8");
|
||||
modelConfig = ModelConfig.parse(JSON.parse(config));
|
||||
return modelConfig;
|
||||
} catch (error) {
|
||||
console.error(`Warning! model config not found!`);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export async function updateModelConfig(config: z.infer<typeof ModelConfig>) {
|
||||
modelConfig = config;
|
||||
const configPath = path.join(WorkDir, "config", "models.json");
|
||||
await fs.promises.writeFile(configPath, JSON.stringify(config, null, 2));
|
||||
}
|
||||
|
||||
ensureDirs();
|
||||
const { mcpServers } = loadMcpServerConfig();
|
||||
export const McpServers = mcpServers;
|
||||
export const ModelConfig = loadModelConfig();
|
||||
export const McpServers = mcpServers;
|
||||
|
|
@ -1,7 +1,14 @@
|
|||
import z from "zod";
|
||||
|
||||
export const Provider = z.object({
|
||||
flavor: z.enum(["openai", "anthropic", "google", "ollama"]),
|
||||
flavor: z.enum([
|
||||
"anthropic",
|
||||
"google",
|
||||
"ollama",
|
||||
"openai",
|
||||
"openai-compatible",
|
||||
"openrouter",
|
||||
]),
|
||||
apiKey: z.string().optional(),
|
||||
baseURL: z.string().optional(),
|
||||
headers: z.record(z.string(), z.string()).optional(),
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import { jsonSchema, ModelMessage } from "ai";
|
||||
import fs from "fs";
|
||||
import path from "path";
|
||||
import { ModelConfig, WorkDir } from "../config/config.js";
|
||||
import { getModelConfig, WorkDir } from "../config/config.js";
|
||||
import { Agent, ToolAttachment } from "../entities/agent.js";
|
||||
import { AssistantContentPart, AssistantMessage, Message, MessageList, ToolCallPart, ToolMessage, UserMessage } from "../entities/message.js";
|
||||
import { runIdGenerator } from "./run-id-gen.js";
|
||||
|
|
@ -405,6 +405,12 @@ export class AgentState {
|
|||
}
|
||||
|
||||
export async function* streamAgent(state: AgentState): AsyncGenerator<z.infer<typeof RunEvent>, void, unknown> {
|
||||
// get model config
|
||||
const modelConfig = await getModelConfig();
|
||||
if (!modelConfig) {
|
||||
throw new Error("Model config not found");
|
||||
}
|
||||
|
||||
// set up agent
|
||||
const agent = await loadAgent(state.agentName);
|
||||
|
||||
|
|
@ -412,8 +418,8 @@ export async function* streamAgent(state: AgentState): AsyncGenerator<z.infer<ty
|
|||
const tools = await buildTools(agent);
|
||||
|
||||
// set up provider + model
|
||||
const provider = getProvider(agent.provider);
|
||||
const model = provider(agent.model || ModelConfig.defaults.model);
|
||||
const provider = await getProvider(agent.provider);
|
||||
const model = provider.languageModel(agent.model || modelConfig.defaults.model);
|
||||
let loopCounter = 0;
|
||||
|
||||
while (true) {
|
||||
|
|
|
|||
|
|
@ -1,50 +1,76 @@
|
|||
import { createOpenAI, OpenAIProvider } from "@ai-sdk/openai";
|
||||
import { createGoogleGenerativeAI, GoogleGenerativeAIProvider } from "@ai-sdk/google";
|
||||
import { AnthropicProvider, createAnthropic } from "@ai-sdk/anthropic";
|
||||
import { OllamaProvider, createOllama } from "ollama-ai-provider-v2";
|
||||
import { ModelConfig } from "../config/config.js";
|
||||
import { ProviderV2 } from "@ai-sdk/provider";
|
||||
import { createOpenAI } from "@ai-sdk/openai";
|
||||
import { createGoogleGenerativeAI } from "@ai-sdk/google";
|
||||
import { createAnthropic } from "@ai-sdk/anthropic";
|
||||
import { createOllama } from "ollama-ai-provider-v2";
|
||||
import { createOpenRouter } from '@openrouter/ai-sdk-provider';
|
||||
import { createOpenAICompatible } from '@ai-sdk/openai-compatible';
|
||||
import { getModelConfig } from "../config/config.js";
|
||||
|
||||
const providerMap: Record<string, OpenAIProvider | GoogleGenerativeAIProvider | AnthropicProvider | OllamaProvider> = {};
|
||||
const providerMap: Record<string, ProviderV2> = {};
|
||||
|
||||
export function getProvider(name: string = "") {
|
||||
export async function getProvider(name: string = ""): Promise<ProviderV2> {
|
||||
// get model conf
|
||||
const modelConfig = await getModelConfig();
|
||||
if (!modelConfig) {
|
||||
throw new Error("Model config not found");
|
||||
}
|
||||
if (!name) {
|
||||
name = ModelConfig.defaults.provider;
|
||||
name = modelConfig.defaults.provider;
|
||||
}
|
||||
if (providerMap[name]) {
|
||||
return providerMap[name];
|
||||
}
|
||||
const providerConfig = ModelConfig.providers[name];
|
||||
const providerConfig = modelConfig.providers[name];
|
||||
if (!providerConfig) {
|
||||
throw new Error(`Provider ${name} not found`);
|
||||
}
|
||||
const { apiKey, baseURL, headers } = providerConfig;
|
||||
switch (providerConfig.flavor) {
|
||||
case "openai":
|
||||
providerMap[name] = createOpenAI({
|
||||
apiKey: providerConfig.apiKey,
|
||||
baseURL: providerConfig.baseURL,
|
||||
headers: providerConfig.headers,
|
||||
apiKey,
|
||||
baseURL,
|
||||
headers,
|
||||
});
|
||||
break;
|
||||
case "anthropic":
|
||||
providerMap[name] = createAnthropic({
|
||||
apiKey: providerConfig.apiKey,
|
||||
baseURL: providerConfig.baseURL,
|
||||
headers: providerConfig.headers,
|
||||
apiKey,
|
||||
baseURL,
|
||||
headers
|
||||
});
|
||||
break;
|
||||
case "google":
|
||||
providerMap[name] = createGoogleGenerativeAI({
|
||||
apiKey: providerConfig.apiKey,
|
||||
baseURL: providerConfig.baseURL,
|
||||
headers: providerConfig.headers,
|
||||
apiKey,
|
||||
baseURL,
|
||||
headers
|
||||
});
|
||||
break;
|
||||
case "ollama":
|
||||
providerMap[name] = createOllama({
|
||||
baseURL: providerConfig.baseURL,
|
||||
headers: providerConfig.headers,
|
||||
baseURL,
|
||||
headers
|
||||
});
|
||||
break;
|
||||
case "openai-compatible":
|
||||
providerMap[name] = createOpenAICompatible({
|
||||
name,
|
||||
apiKey,
|
||||
baseURL : baseURL || "",
|
||||
headers
|
||||
});
|
||||
break;
|
||||
case "openrouter":
|
||||
providerMap[name] = createOpenRouter({
|
||||
apiKey,
|
||||
baseURL,
|
||||
headers
|
||||
});
|
||||
break;
|
||||
default:
|
||||
throw new Error(`Provider ${name} not found`);
|
||||
}
|
||||
return providerMap[name];
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue