build model selection

This commit is contained in:
Ramnique Singh 2025-11-20 16:41:41 +05:30
parent add897e448
commit 92004033de
8 changed files with 357 additions and 66 deletions

View file

@ -1,7 +1,7 @@
import { jsonSchema, ModelMessage } from "ai";
import fs from "fs";
import path from "path";
import { ModelConfig, WorkDir } from "../config/config.js";
import { getModelConfig, WorkDir } from "../config/config.js";
import { Agent, ToolAttachment } from "../entities/agent.js";
import { AssistantContentPart, AssistantMessage, Message, MessageList, ToolCallPart, ToolMessage, UserMessage } from "../entities/message.js";
import { runIdGenerator } from "./run-id-gen.js";
@ -405,6 +405,12 @@ export class AgentState {
}
export async function* streamAgent(state: AgentState): AsyncGenerator<z.infer<typeof RunEvent>, void, unknown> {
// get model config
const modelConfig = await getModelConfig();
if (!modelConfig) {
throw new Error("Model config not found");
}
// set up agent
const agent = await loadAgent(state.agentName);
@ -412,8 +418,8 @@ export async function* streamAgent(state: AgentState): AsyncGenerator<z.infer<ty
const tools = await buildTools(agent);
// set up provider + model
const provider = getProvider(agent.provider);
const model = provider(agent.model || ModelConfig.defaults.model);
const provider = await getProvider(agent.provider);
const model = provider.languageModel(agent.model || modelConfig.defaults.model);
let loopCounter = 0;
while (true) {

View file

@ -1,50 +1,76 @@
import { createOpenAI, OpenAIProvider } from "@ai-sdk/openai";
import { createGoogleGenerativeAI, GoogleGenerativeAIProvider } from "@ai-sdk/google";
import { AnthropicProvider, createAnthropic } from "@ai-sdk/anthropic";
import { OllamaProvider, createOllama } from "ollama-ai-provider-v2";
import { ModelConfig } from "../config/config.js";
import { ProviderV2 } from "@ai-sdk/provider";
import { createOpenAI } from "@ai-sdk/openai";
import { createGoogleGenerativeAI } from "@ai-sdk/google";
import { createAnthropic } from "@ai-sdk/anthropic";
import { createOllama } from "ollama-ai-provider-v2";
import { createOpenRouter } from '@openrouter/ai-sdk-provider';
import { createOpenAICompatible } from '@ai-sdk/openai-compatible';
import { getModelConfig } from "../config/config.js";
const providerMap: Record<string, OpenAIProvider | GoogleGenerativeAIProvider | AnthropicProvider | OllamaProvider> = {};
const providerMap: Record<string, ProviderV2> = {};
export function getProvider(name: string = "") {
export async function getProvider(name: string = ""): Promise<ProviderV2> {
// get model conf
const modelConfig = await getModelConfig();
if (!modelConfig) {
throw new Error("Model config not found");
}
if (!name) {
name = ModelConfig.defaults.provider;
name = modelConfig.defaults.provider;
}
if (providerMap[name]) {
return providerMap[name];
}
const providerConfig = ModelConfig.providers[name];
const providerConfig = modelConfig.providers[name];
if (!providerConfig) {
throw new Error(`Provider ${name} not found`);
}
const { apiKey, baseURL, headers } = providerConfig;
switch (providerConfig.flavor) {
case "openai":
providerMap[name] = createOpenAI({
apiKey: providerConfig.apiKey,
baseURL: providerConfig.baseURL,
headers: providerConfig.headers,
apiKey,
baseURL,
headers,
});
break;
case "anthropic":
providerMap[name] = createAnthropic({
apiKey: providerConfig.apiKey,
baseURL: providerConfig.baseURL,
headers: providerConfig.headers,
apiKey,
baseURL,
headers
});
break;
case "google":
providerMap[name] = createGoogleGenerativeAI({
apiKey: providerConfig.apiKey,
baseURL: providerConfig.baseURL,
headers: providerConfig.headers,
apiKey,
baseURL,
headers
});
break;
case "ollama":
providerMap[name] = createOllama({
baseURL: providerConfig.baseURL,
headers: providerConfig.headers,
baseURL,
headers
});
break;
case "openai-compatible":
providerMap[name] = createOpenAICompatible({
name,
apiKey,
baseURL : baseURL || "",
headers
});
break;
case "openrouter":
providerMap[name] = createOpenRouter({
apiKey,
baseURL,
headers
});
break;
default:
throw new Error(`Provider ${name} not found`);
}
return providerMap[name];
}