diff --git a/apps/cli/src/application/assistant/chat.ts b/apps/cli/src/application/assistant/chat.ts index 15969d29..bc1a77c3 100644 --- a/apps/cli/src/application/assistant/chat.ts +++ b/apps/cli/src/application/assistant/chat.ts @@ -10,7 +10,7 @@ import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/ import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js"; import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { getProvider } from "../lib/models.js"; -import { DefaultModel } from "../config/config.js"; +import { ModelConfig } from "../config/config.js"; const rl = readline.createInterface({ input, output }); @@ -59,7 +59,7 @@ export async function startCopilot() { let currentStep = 0; const provider = getProvider(); const result = streamText({ - model: provider(DefaultModel), + model: provider(ModelConfig.defaults.model), messages: messages, system: `You are an intelligent workflow assistant helping users manage their workflows in ${BASE_DIR}. diff --git a/apps/cli/src/application/config/config.ts b/apps/cli/src/application/config/config.ts index 12533ce3..f28a03fb 100644 --- a/apps/cli/src/application/config/config.ts +++ b/apps/cli/src/application/config/config.ts @@ -1,7 +1,7 @@ import path from "path"; import fs from "fs"; import { McpServerConfig } from "../entities/mcp.js"; -import { ModelConfig } from "../entities/models.js"; +import { ModelConfig as ModelConfigT } from "../entities/models.js"; import { z } from "zod"; import { homedir } from "os"; @@ -26,7 +26,7 @@ const baseMcpConfig: z.infer = { } }; -const baseModelConfig: z.infer = { +const baseModelConfig: z.infer = { providers: { openai: { flavor: "openai", @@ -71,16 +71,13 @@ function loadMcpServerConfig(): z.infer { return McpServerConfig.parse(JSON.parse(config)); } -function loadModelConfig(): z.infer { +function loadModelConfig(): z.infer { const configPath = path.join(WorkDir, "config", "models.json"); if (!fs.existsSync(configPath)) return baseModelConfig; const config = fs.readFileSync(configPath, "utf8"); - return ModelConfig.parse(JSON.parse(config)); + return ModelConfigT.parse(JSON.parse(config)); } const { mcpServers } = loadMcpServerConfig(); -const { providers, defaults } = loadModelConfig(); export const McpServers = mcpServers; -export const Providers = providers; -export const DefaultModel = defaults.model; -export const DefaultProvider = defaults.provider; +export const ModelConfig = loadModelConfig(); diff --git a/apps/cli/src/application/lib/agent.ts b/apps/cli/src/application/lib/agent.ts index 7a1d09a9..e43d2427 100644 --- a/apps/cli/src/application/lib/agent.ts +++ b/apps/cli/src/application/lib/agent.ts @@ -3,7 +3,7 @@ import { z } from "zod"; import { Step, StepInputT, StepOutputT } from "./step.js"; import { ModelMessage, stepCountIs, streamText, tool, Tool, ToolSet, jsonSchema } from "ai"; import { Agent, AgentTool } from "../entities/agent.js"; -import { DefaultModel, WorkDir } from "../config/config.js"; +import { ModelConfig, WorkDir } from "../config/config.js"; import fs from "fs"; import path from "path"; import { loadWorkflow } from "./utils.js"; @@ -158,7 +158,7 @@ export class AgentNode implements Step { const provider = getProvider(this.agent.provider); const { fullStream } = streamText({ - model: provider(this.agent.model || DefaultModel), + model: provider(this.agent.model || ModelConfig.defaults.model), messages: convertFromMessages(input), system: this.agent.instructions, stopWhen: stepCountIs(1), diff --git a/apps/cli/src/application/lib/models.ts b/apps/cli/src/application/lib/models.ts index 74a1b36d..de35dba3 100644 --- a/apps/cli/src/application/lib/models.ts +++ b/apps/cli/src/application/lib/models.ts @@ -1,18 +1,18 @@ import { createOpenAI, OpenAIProvider } from "@ai-sdk/openai"; import { createGoogleGenerativeAI, GoogleGenerativeAIProvider } from "@ai-sdk/google"; import { AnthropicProvider, createAnthropic } from "@ai-sdk/anthropic"; -import { DefaultModel, DefaultProvider, Providers } from "../config/config.js"; +import { ModelConfig } from "../config/config.js"; const providerMap: Record = {}; export function getProvider(name: string = "") { if (!name) { - name = DefaultProvider; + name = ModelConfig.defaults.provider; } if (providerMap[name]) { return providerMap[name]; } - const providerConfig = Providers[name]; + const providerConfig = ModelConfig.providers[name]; if (!providerConfig) { throw new Error(`Provider ${name} not found`); }