fix rb ai gateway integration

This commit is contained in:
Ramnique Singh 2025-12-01 11:59:13 +05:30
parent 8feb4f1425
commit ab617e63b5
9 changed files with 165 additions and 81 deletions

View file

@@ -14,6 +14,7 @@ import { Example } from "./application/entities/example.js";
import { z } from "zod";
import { Flavor } from "./application/entities/models.js";
import { examples } from "./examples/index.js";
import { modelMessageSchema } from "ai";
export async function updateState(agent: string, runId: string) {
const state = new AgentState(agent, runId);
@@ -225,6 +226,7 @@ export async function modelConfig() {
const defaultApiKeyEnvVars: Record<z.infer<typeof Flavor>, string> = {
"rowboat [free]": "",
openai: "OPENAI_API_KEY",
aigateway: "AI_GATEWAY_API_KEY",
anthropic: "ANTHROPIC_API_KEY",
google: "GOOGLE_GENERATIVE_AI_API_KEY",
ollama: "",
@@ -234,6 +236,7 @@ export async function modelConfig() {
const defaultBaseUrls: Record<z.infer<typeof Flavor>, string> = {
"rowboat [free]": "",
openai: "https://api.openai.com/v1",
aigateway: "https://ai-gateway.vercel.sh/v1/ai",
anthropic: "https://api.anthropic.com/v1",
google: "https://generativelanguage.googleapis.com/v1beta",
ollama: "http://localhost:11434",
@@ -243,6 +246,7 @@ export async function modelConfig() {
const defaultModels: Record<z.infer<typeof Flavor>, string> = {
"rowboat [free]": "google/gemini-3-pro-preview",
openai: "gpt-5.1",
aigateway: "gpt-5.1",
anthropic: "claude-sonnet-4-5",
google: "gemini-2.5-pro",
ollama: "llama3.1",

View file

@@ -1,40 +1,46 @@
import { z } from "zod";
import { ProviderOptions } from "./message.js";
export const LlmStepStreamReasoningStartEvent = z.object({
const BaseEvent = z.object({
providerOptions: ProviderOptions.optional(),
})
export const LlmStepStreamReasoningStartEvent = BaseEvent.extend({
type: z.literal("reasoning-start"),
});
export const LlmStepStreamReasoningDeltaEvent = z.object({
export const LlmStepStreamReasoningDeltaEvent = BaseEvent.extend({
type: z.literal("reasoning-delta"),
delta: z.string(),
});
export const LlmStepStreamReasoningEndEvent = z.object({
export const LlmStepStreamReasoningEndEvent = BaseEvent.extend({
type: z.literal("reasoning-end"),
});
export const LlmStepStreamTextStartEvent = z.object({
export const LlmStepStreamTextStartEvent = BaseEvent.extend({
type: z.literal("text-start"),
});
export const LlmStepStreamTextDeltaEvent = z.object({
export const LlmStepStreamTextDeltaEvent = BaseEvent.extend({
type: z.literal("text-delta"),
delta: z.string(),
});
export const LlmStepStreamTextEndEvent = z.object({
export const LlmStepStreamTextEndEvent = BaseEvent.extend({
type: z.literal("text-end"),
});
export const LlmStepStreamToolCallEvent = z.object({
export const LlmStepStreamToolCallEvent = BaseEvent.extend({
type: z.literal("tool-call"),
toolCallId: z.string(),
toolName: z.string(),
input: z.any(),
});
export const LlmStepStreamUsageEvent = z.object({
type: z.literal("usage"),
export const LlmStepStreamFinishStepEvent = z.object({
type: z.literal("finish-step"),
finishReason: z.enum(["stop", "tool-calls", "length", "content-filter", "error", "other", "unknown"]),
usage: z.object({
inputTokens: z.number().optional(),
outputTokens: z.number().optional(),
@@ -42,6 +48,7 @@ export const LlmStepStreamUsageEvent = z.object({
reasoningTokens: z.number().optional(),
cachedInputTokens: z.number().optional(),
}),
providerOptions: ProviderOptions.optional(),
});
export const LlmStepStreamEvent = z.union([
@@ -52,5 +59,5 @@ export const LlmStepStreamEvent = z.union([
LlmStepStreamTextDeltaEvent,
LlmStepStreamTextEndEvent,
LlmStepStreamToolCallEvent,
LlmStepStreamUsageEvent,
LlmStepStreamFinishStepEvent,
]);

View file

@@ -1,13 +1,17 @@
import { z } from "zod";
export const ProviderOptions = z.record(z.string(), z.record(z.string(), z.json()));
export const TextPart = z.object({
type: z.literal("text"),
text: z.string(),
providerOptions: ProviderOptions.optional(),
});
export const ReasoningPart = z.object({
type: z.literal("reasoning"),
text: z.string(),
providerOptions: ProviderOptions.optional(),
});
export const ToolCallPart = z.object({
@@ -15,6 +19,7 @@ export const ToolCallPart = z.object({
toolCallId: z.string(),
toolName: z.string(),
arguments: z.any(),
providerOptions: ProviderOptions.optional(),
});
export const AssistantContentPart = z.union([
@@ -26,6 +31,7 @@ export const AssistantContentPart = z.union([
export const UserMessage = z.object({
role: z.literal("user"),
content: z.string(),
providerOptions: ProviderOptions.optional(),
});
export const AssistantMessage = z.object({
@@ -34,11 +40,13 @@ export const AssistantMessage = z.object({
z.string(),
z.array(AssistantContentPart),
]),
providerOptions: ProviderOptions.optional(),
});
export const SystemMessage = z.object({
role: z.literal("system"),
content: z.string(),
providerOptions: ProviderOptions.optional(),
});
export const ToolMessage = z.object({
@@ -46,6 +54,7 @@ export const ToolMessage = z.object({
content: z.string(),
toolCallId: z.string(),
toolName: z.string(),
providerOptions: ProviderOptions.optional(),
});
export const Message = z.discriminatedUnion("role", [

View file

@@ -2,6 +2,7 @@ import z from "zod";
export const Flavor = z.enum([
"rowboat [free]",
"aigateway",
"anthropic",
"google",
"ollama",

View file

@@ -1,9 +1,9 @@
import { jsonSchema, ModelMessage } from "ai";
import { jsonSchema, ModelMessage, modelMessageSchema } from "ai";
import fs from "fs";
import path from "path";
import { getModelConfig, WorkDir } from "../config/config.js";
import { Agent, ToolAttachment } from "../entities/agent.js";
import { AssistantContentPart, AssistantMessage, Message, MessageList, ToolCallPart, ToolMessage, UserMessage } from "../entities/message.js";
import { AssistantContentPart, AssistantMessage, Message, MessageList, ProviderOptions, ToolCallPart, ToolMessage, UserMessage } from "../entities/message.js";
import { runIdGenerator } from "./run-id-gen.js";
import { LanguageModel, stepCountIs, streamText, tool, Tool, ToolSet } from "ai";
import { z } from "zod";
@@ -90,6 +90,7 @@ export class StreamStepMessageBuilder {
private parts: z.infer<typeof AssistantContentPart>[] = [];
private textBuffer: string = "";
private reasoningBuffer: string = "";
private providerOptions: z.infer<typeof ProviderOptions> | undefined = undefined;
flushBuffers() {
// skip reasoning
@@ -123,8 +124,12 @@ export class StreamStepMessageBuilder {
toolCallId: event.toolCallId,
toolName: event.toolName,
arguments: event.input,
providerOptions: event.providerOptions,
});
break;
case "finish-step":
this.providerOptions = event.providerOptions;
break;
}
}
@@ -133,6 +138,7 @@ export class StreamStepMessageBuilder {
return {
role: "assistant",
content: this.parts,
providerOptions: this.providerOptions,
};
}
}
@@ -173,12 +179,14 @@ export async function loadAgent(id: string): Promise<z.infer<typeof Agent>> {
export function convertFromMessages(messages: z.infer<typeof Message>[]): ModelMessage[] {
const result: ModelMessage[] = [];
for (const msg of messages) {
const { providerOptions } = msg;
switch (msg.role) {
case "assistant":
if (typeof msg.content === 'string') {
result.push({
role: "assistant",
content: msg.content,
providerOptions,
});
} else {
result.push({
@@ -195,9 +203,11 @@ export function convertFromMessages(messages: z.infer<typeof Message>[]): ModelM
toolCallId: part.toolCallId,
toolName: part.toolName,
input: part.arguments,
providerOptions: part.providerOptions,
};
}
}),
providerOptions,
});
}
break;
@@ -205,12 +215,14 @@ export function convertFromMessages(messages: z.infer<typeof Message>[]): ModelM
result.push({
role: "system",
content: msg.content,
providerOptions,
});
break;
case "user":
result.push({
role: "user",
content: msg.content,
providerOptions,
});
break;
case "tool":
@@ -227,11 +239,13 @@ export function convertFromMessages(messages: z.infer<typeof Message>[]): ModelM
},
},
],
providerOptions,
});
break;
}
}
return result;
// doing this because: https://github.com/OpenRouterTeam/ai-sdk-provider/issues/262
return JSON.parse(JSON.stringify(result));
}
async function buildTools(agent: z.infer<typeof Agent>): Promise<ToolSet> {
@@ -446,7 +460,7 @@ export async function* streamAgent(state: AgentState): AsyncGenerator<z.infer<ty
}
// if tool has been denied, deny
if (state.deniedToolCallIds[toolCallId]) {
if (state.deniedToolCallIds[toolCallId]) {
yield* state.ingestAndLogAndYield({
type: "message",
message: {
@@ -561,7 +575,7 @@ export async function* streamAgent(state: AgentState): AsyncGenerator<z.infer<ty
if (underlyingTool.type === "builtin" && underlyingTool.name === "executeCommand") {
// if command is blocked, then seek permission
if (isBlocked(part.arguments.command)) {
yield *state.ingestAndLogAndYield({
yield* state.ingestAndLogAndYield({
type: "tool-permission-request",
toolCall: part,
subflow: [],
@@ -609,28 +623,33 @@ async function* streamLlm(
case "reasoning-start":
yield {
type: "reasoning-start",
providerOptions: event.providerMetadata,
};
break;
case "reasoning-delta":
yield {
type: "reasoning-delta",
delta: event.text,
providerOptions: event.providerMetadata,
};
break;
case "reasoning-end":
yield {
type: "reasoning-end",
providerOptions: event.providerMetadata,
};
break;
case "text-start":
yield {
type: "text-start",
providerOptions: event.providerMetadata,
};
break;
case "text-delta":
yield {
type: "text-delta",
delta: event.text,
providerOptions: event.providerMetadata,
};
break;
case "tool-call":
@@ -639,12 +658,15 @@ async function* streamLlm(
toolCallId: event.toolCallId,
toolName: event.toolName,
input: event.input,
providerOptions: event.providerMetadata,
};
break;
case "finish":
case "finish-step":
yield {
type: "usage",
usage: event.totalUsage,
type: "finish-step",
usage: event.usage,
finishReason: event.finishReason,
providerOptions: event.providerMetadata,
};
break;
default:

View file

@@ -1,4 +1,5 @@
import { ProviderV2 } from "@ai-sdk/provider";
import { createGateway } from "ai";
import { createOpenAI } from "@ai-sdk/openai";
import { createGoogleGenerativeAI } from "@ai-sdk/google";
import { createAnthropic } from "@ai-sdk/anthropic";
@@ -28,9 +29,9 @@ export async function getProvider(name: string = ""): Promise<ProviderV2> {
const { apiKey, baseURL, headers } = providerConfig;
switch (providerConfig.flavor) {
case "rowboat [free]":
providerMap[name] = createOpenAICompatible({
name: "rowboat [free]",
baseURL: "https://ai-gateway.rowboatlabs.com/v1",
providerMap[name] = createGateway({
apiKey: "rowboatx",
baseURL: "https://ai-gateway.rowboatlabs.com/v1/ai",
});
break;
case "openai":
@@ -40,6 +41,13 @@ export async function getProvider(name: string = ""): Promise<ProviderV2> {
headers,
});
break;
case "aigateway":
providerMap[name] = createGateway({
apiKey,
baseURL,
headers
});
break;
case "anthropic":
providerMap[name] = createAnthropic({
apiKey,
@@ -65,7 +73,7 @@ export async function getProvider(name: string = ""): Promise<ProviderV2> {
name,
apiKey,
baseURL : baseURL || "",
headers
headers,
});
break;
case "openrouter":

View file

@@ -77,8 +77,8 @@ export class StreamRenderer {
case "tool-call":
this.onToolCall(event.toolCallId, event.toolName, event.input);
break;
case "usage":
this.onUsage(event.usage);
case "finish-step":
this.onFinishStep(event.finishReason, event.usage);
break;
}
}
@@ -219,13 +219,15 @@ export class StreamRenderer {
this.write("\n");
}
private onUsage(usage: {
inputTokens?: number;
outputTokens?: number;
totalTokens?: number;
reasoningTokens?: number;
cachedInputTokens?: number;
}) {
private onFinishStep(
finishReason: "stop" | "tool-calls" | "length" | "content-filter" | "error" | "other" | "unknown",
usage: {
inputTokens?: number;
outputTokens?: number;
totalTokens?: number;
reasoningTokens?: number;
cachedInputTokens?: number;
}) {
const parts: string[] = [];
if (usage.inputTokens !== undefined) parts.push(`${this.dim("in:")} ${usage.inputTokens}`);
if (usage.outputTokens !== undefined) parts.push(`${this.dim("out:")} ${usage.outputTokens}`);
@@ -234,8 +236,13 @@ export class StreamRenderer {
if (usage.totalTokens !== undefined) parts.push(`${this.dim("total:")} ${this.bold(usage.totalTokens.toString())}`);
const line = parts.join(this.dim(" | "));
this.write("\n");
this.write(this.dim("╭─ Usage\n"));
this.write(this.dim("│ ") + line);
this.write(this.bold("╭─ ") + this.bold("Finish"));
this.write("\n");
this.write(this.dim("│ ") + this.dim("reason: ") + finishReason);
if (line.length) {
this.write("\n");
this.write(this.dim("│ ") + line);
}
this.write("\n");
this.write(this.dim("╰─────────────\n"));
}