fix rb ai gateway integration

This commit is contained in:
Ramnique Singh 2025-12-01 11:59:13 +05:30
parent 8feb4f1425
commit ab617e63b5
9 changed files with 165 additions and 81 deletions

View file

@ -1,9 +1,9 @@
import { jsonSchema, ModelMessage } from "ai";
import { jsonSchema, ModelMessage, modelMessageSchema } from "ai";
import fs from "fs";
import path from "path";
import { getModelConfig, WorkDir } from "../config/config.js";
import { Agent, ToolAttachment } from "../entities/agent.js";
import { AssistantContentPart, AssistantMessage, Message, MessageList, ToolCallPart, ToolMessage, UserMessage } from "../entities/message.js";
import { AssistantContentPart, AssistantMessage, Message, MessageList, ProviderOptions, ToolCallPart, ToolMessage, UserMessage } from "../entities/message.js";
import { runIdGenerator } from "./run-id-gen.js";
import { LanguageModel, stepCountIs, streamText, tool, Tool, ToolSet } from "ai";
import { z } from "zod";
@ -90,6 +90,7 @@ export class StreamStepMessageBuilder {
private parts: z.infer<typeof AssistantContentPart>[] = [];
private textBuffer: string = "";
private reasoningBuffer: string = "";
private providerOptions: z.infer<typeof ProviderOptions> | undefined = undefined;
flushBuffers() {
// skip reasoning
@ -123,8 +124,12 @@ export class StreamStepMessageBuilder {
toolCallId: event.toolCallId,
toolName: event.toolName,
arguments: event.input,
providerOptions: event.providerOptions,
});
break;
case "finish-step":
this.providerOptions = event.providerOptions;
break;
}
}
@ -133,6 +138,7 @@ export class StreamStepMessageBuilder {
return {
role: "assistant",
content: this.parts,
providerOptions: this.providerOptions,
};
}
}
@ -173,12 +179,14 @@ export async function loadAgent(id: string): Promise<z.infer<typeof Agent>> {
export function convertFromMessages(messages: z.infer<typeof Message>[]): ModelMessage[] {
const result: ModelMessage[] = [];
for (const msg of messages) {
const { providerOptions } = msg;
switch (msg.role) {
case "assistant":
if (typeof msg.content === 'string') {
result.push({
role: "assistant",
content: msg.content,
providerOptions,
});
} else {
result.push({
@ -195,9 +203,11 @@ export function convertFromMessages(messages: z.infer<typeof Message>[]): ModelM
toolCallId: part.toolCallId,
toolName: part.toolName,
input: part.arguments,
providerOptions: part.providerOptions,
};
}
}),
providerOptions,
});
}
break;
@ -205,12 +215,14 @@ export function convertFromMessages(messages: z.infer<typeof Message>[]): ModelM
result.push({
role: "system",
content: msg.content,
providerOptions,
});
break;
case "user":
result.push({
role: "user",
content: msg.content,
providerOptions,
});
break;
case "tool":
@ -227,11 +239,13 @@ export function convertFromMessages(messages: z.infer<typeof Message>[]): ModelM
},
},
],
providerOptions,
});
break;
}
}
return result;
// doing this because: https://github.com/OpenRouterTeam/ai-sdk-provider/issues/262
return JSON.parse(JSON.stringify(result));
}
async function buildTools(agent: z.infer<typeof Agent>): Promise<ToolSet> {
@ -446,7 +460,7 @@ export async function* streamAgent(state: AgentState): AsyncGenerator<z.infer<ty
}
// if tool has been denied, deny
if (state.deniedToolCallIds[toolCallId]) {
if (state.deniedToolCallIds[toolCallId]) {
yield* state.ingestAndLogAndYield({
type: "message",
message: {
@ -561,7 +575,7 @@ export async function* streamAgent(state: AgentState): AsyncGenerator<z.infer<ty
if (underlyingTool.type === "builtin" && underlyingTool.name === "executeCommand") {
// if command is blocked, then seek permission
if (isBlocked(part.arguments.command)) {
yield *state.ingestAndLogAndYield({
yield* state.ingestAndLogAndYield({
type: "tool-permission-request",
toolCall: part,
subflow: [],
@ -609,28 +623,33 @@ async function* streamLlm(
case "reasoning-start":
yield {
type: "reasoning-start",
providerOptions: event.providerMetadata,
};
break;
case "reasoning-delta":
yield {
type: "reasoning-delta",
delta: event.text,
providerOptions: event.providerMetadata,
};
break;
case "reasoning-end":
yield {
type: "reasoning-end",
providerOptions: event.providerMetadata,
};
break;
case "text-start":
yield {
type: "text-start",
providerOptions: event.providerMetadata,
};
break;
case "text-delta":
yield {
type: "text-delta",
delta: event.text,
providerOptions: event.providerMetadata,
};
break;
case "tool-call":
@ -639,12 +658,15 @@ async function* streamLlm(
toolCallId: event.toolCallId,
toolName: event.toolName,
input: event.input,
providerOptions: event.providerMetadata,
};
break;
case "finish":
case "finish-step":
yield {
type: "usage",
usage: event.totalUsage,
type: "finish-step",
usage: event.usage,
finishReason: event.finishReason,
providerOptions: event.providerMetadata,
};
break;
default: