refactor model / provider code

This commit is contained in:
Ramnique Singh 2025-11-14 09:23:37 +05:30
parent 61924d0b01
commit fb355ec10d
4 changed files with 12 additions and 15 deletions

View file

@ -10,7 +10,7 @@ import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/
import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js";
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { getProvider } from "../lib/models.js";
import { DefaultModel } from "../config/config.js";
import { ModelConfig } from "../config/config.js";
const rl = readline.createInterface({ input, output });
@ -59,7 +59,7 @@ export async function startCopilot() {
let currentStep = 0;
const provider = getProvider();
const result = streamText({
model: provider(DefaultModel),
model: provider(ModelConfig.defaults.model),
messages: messages,
system: `You are an intelligent workflow assistant helping users manage their workflows in ${BASE_DIR}.

View file

@ -1,7 +1,7 @@
import path from "path";
import fs from "fs";
import { McpServerConfig } from "../entities/mcp.js";
import { ModelConfig } from "../entities/models.js";
import { ModelConfig as ModelConfigT } from "../entities/models.js";
import { z } from "zod";
import { homedir } from "os";
@ -26,7 +26,7 @@ const baseMcpConfig: z.infer<typeof McpServerConfig> = {
}
};
const baseModelConfig: z.infer<typeof ModelConfig> = {
const baseModelConfig: z.infer<typeof ModelConfigT> = {
providers: {
openai: {
flavor: "openai",
@ -71,16 +71,13 @@ function loadMcpServerConfig(): z.infer<typeof McpServerConfig> {
return McpServerConfig.parse(JSON.parse(config));
}
function loadModelConfig(): z.infer<typeof ModelConfig> {
function loadModelConfig(): z.infer<typeof ModelConfigT> {
const configPath = path.join(WorkDir, "config", "models.json");
if (!fs.existsSync(configPath)) return baseModelConfig;
const config = fs.readFileSync(configPath, "utf8");
return ModelConfig.parse(JSON.parse(config));
return ModelConfigT.parse(JSON.parse(config));
}
const { mcpServers } = loadMcpServerConfig();
const { providers, defaults } = loadModelConfig();
export const McpServers = mcpServers;
export const Providers = providers;
export const DefaultModel = defaults.model;
export const DefaultProvider = defaults.provider;
export const ModelConfig = loadModelConfig();

View file

@ -3,7 +3,7 @@ import { z } from "zod";
import { Step, StepInputT, StepOutputT } from "./step.js";
import { ModelMessage, stepCountIs, streamText, tool, Tool, ToolSet, jsonSchema } from "ai";
import { Agent, AgentTool } from "../entities/agent.js";
import { DefaultModel, WorkDir } from "../config/config.js";
import { ModelConfig, WorkDir } from "../config/config.js";
import fs from "fs";
import path from "path";
import { loadWorkflow } from "./utils.js";
@ -158,7 +158,7 @@ export class AgentNode implements Step {
const provider = getProvider(this.agent.provider);
const { fullStream } = streamText({
model: provider(this.agent.model || DefaultModel),
model: provider(this.agent.model || ModelConfig.defaults.model),
messages: convertFromMessages(input),
system: this.agent.instructions,
stopWhen: stepCountIs(1),

View file

@ -1,18 +1,18 @@
import { createOpenAI, OpenAIProvider } from "@ai-sdk/openai";
import { createGoogleGenerativeAI, GoogleGenerativeAIProvider } from "@ai-sdk/google";
import { AnthropicProvider, createAnthropic } from "@ai-sdk/anthropic";
import { DefaultModel, DefaultProvider, Providers } from "../config/config.js";
import { ModelConfig } from "../config/config.js";
const providerMap: Record<string, OpenAIProvider | GoogleGenerativeAIProvider | AnthropicProvider> = {};
export function getProvider(name: string = "") {
if (!name) {
name = DefaultProvider;
name = ModelConfig.defaults.provider;
}
if (providerMap[name]) {
return providerMap[name];
}
const providerConfig = Providers[name];
const providerConfig = ModelConfig.providers[name];
if (!providerConfig) {
throw new Error(`Provider ${name} not found`);
}