mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-04-25 16:36:22 +02:00
fix rb ai gateway integration
This commit is contained in:
parent
8feb4f1425
commit
ab617e63b5
9 changed files with 165 additions and 81 deletions
|
|
@ -14,6 +14,7 @@ import { Example } from "./application/entities/example.js";
|
|||
import { z } from "zod";
|
||||
import { Flavor } from "./application/entities/models.js";
|
||||
import { examples } from "./examples/index.js";
|
||||
import { modelMessageSchema } from "ai";
|
||||
|
||||
export async function updateState(agent: string, runId: string) {
|
||||
const state = new AgentState(agent, runId);
|
||||
|
|
@ -225,6 +226,7 @@ export async function modelConfig() {
|
|||
const defaultApiKeyEnvVars: Record<z.infer<typeof Flavor>, string> = {
|
||||
"rowboat [free]": "",
|
||||
openai: "OPENAI_API_KEY",
|
||||
aigateway: "AI_GATEWAY_API_KEY",
|
||||
anthropic: "ANTHROPIC_API_KEY",
|
||||
google: "GOOGLE_GENERATIVE_AI_API_KEY",
|
||||
ollama: "",
|
||||
|
|
@ -234,6 +236,7 @@ export async function modelConfig() {
|
|||
const defaultBaseUrls: Record<z.infer<typeof Flavor>, string> = {
|
||||
"rowboat [free]": "",
|
||||
openai: "https://api.openai.com/v1",
|
||||
aigateway: "https://ai-gateway.vercel.sh/v1/ai",
|
||||
anthropic: "https://api.anthropic.com/v1",
|
||||
google: "https://generativelanguage.googleapis.com/v1beta",
|
||||
ollama: "http://localhost:11434",
|
||||
|
|
@ -243,6 +246,7 @@ export async function modelConfig() {
|
|||
const defaultModels: Record<z.infer<typeof Flavor>, string> = {
|
||||
"rowboat [free]": "google/gemini-3-pro-preview",
|
||||
openai: "gpt-5.1",
|
||||
aigateway: "gpt-5.1",
|
||||
anthropic: "claude-sonnet-4-5",
|
||||
google: "gemini-2.5-pro",
|
||||
ollama: "llama3.1",
|
||||
|
|
|
|||
|
|
@ -1,40 +1,46 @@
|
|||
import { z } from "zod";
|
||||
import { ProviderOptions } from "./message.js";
|
||||
|
||||
export const LlmStepStreamReasoningStartEvent = z.object({
|
||||
const BaseEvent = z.object({
|
||||
providerOptions: ProviderOptions.optional(),
|
||||
})
|
||||
|
||||
export const LlmStepStreamReasoningStartEvent = BaseEvent.extend({
|
||||
type: z.literal("reasoning-start"),
|
||||
});
|
||||
|
||||
export const LlmStepStreamReasoningDeltaEvent = z.object({
|
||||
export const LlmStepStreamReasoningDeltaEvent = BaseEvent.extend({
|
||||
type: z.literal("reasoning-delta"),
|
||||
delta: z.string(),
|
||||
});
|
||||
|
||||
export const LlmStepStreamReasoningEndEvent = z.object({
|
||||
export const LlmStepStreamReasoningEndEvent = BaseEvent.extend({
|
||||
type: z.literal("reasoning-end"),
|
||||
});
|
||||
|
||||
export const LlmStepStreamTextStartEvent = z.object({
|
||||
export const LlmStepStreamTextStartEvent = BaseEvent.extend({
|
||||
type: z.literal("text-start"),
|
||||
});
|
||||
|
||||
export const LlmStepStreamTextDeltaEvent = z.object({
|
||||
export const LlmStepStreamTextDeltaEvent = BaseEvent.extend({
|
||||
type: z.literal("text-delta"),
|
||||
delta: z.string(),
|
||||
});
|
||||
|
||||
export const LlmStepStreamTextEndEvent = z.object({
|
||||
export const LlmStepStreamTextEndEvent = BaseEvent.extend({
|
||||
type: z.literal("text-end"),
|
||||
});
|
||||
|
||||
export const LlmStepStreamToolCallEvent = z.object({
|
||||
export const LlmStepStreamToolCallEvent = BaseEvent.extend({
|
||||
type: z.literal("tool-call"),
|
||||
toolCallId: z.string(),
|
||||
toolName: z.string(),
|
||||
input: z.any(),
|
||||
});
|
||||
|
||||
export const LlmStepStreamUsageEvent = z.object({
|
||||
type: z.literal("usage"),
|
||||
export const LlmStepStreamFinishStepEvent = z.object({
|
||||
type: z.literal("finish-step"),
|
||||
finishReason: z.enum(["stop", "tool-calls", "length", "content-filter", "error", "other", "unknown"]),
|
||||
usage: z.object({
|
||||
inputTokens: z.number().optional(),
|
||||
outputTokens: z.number().optional(),
|
||||
|
|
@ -42,6 +48,7 @@ export const LlmStepStreamUsageEvent = z.object({
|
|||
reasoningTokens: z.number().optional(),
|
||||
cachedInputTokens: z.number().optional(),
|
||||
}),
|
||||
providerOptions: ProviderOptions.optional(),
|
||||
});
|
||||
|
||||
export const LlmStepStreamEvent = z.union([
|
||||
|
|
@ -52,5 +59,5 @@ export const LlmStepStreamEvent = z.union([
|
|||
LlmStepStreamTextDeltaEvent,
|
||||
LlmStepStreamTextEndEvent,
|
||||
LlmStepStreamToolCallEvent,
|
||||
LlmStepStreamUsageEvent,
|
||||
LlmStepStreamFinishStepEvent,
|
||||
]);
|
||||
|
|
@ -1,13 +1,17 @@
|
|||
import { z } from "zod";
|
||||
|
||||
export const ProviderOptions = z.record(z.string(), z.record(z.string(), z.json()));
|
||||
|
||||
export const TextPart = z.object({
|
||||
type: z.literal("text"),
|
||||
text: z.string(),
|
||||
providerOptions: ProviderOptions.optional(),
|
||||
});
|
||||
|
||||
export const ReasoningPart = z.object({
|
||||
type: z.literal("reasoning"),
|
||||
text: z.string(),
|
||||
providerOptions: ProviderOptions.optional(),
|
||||
});
|
||||
|
||||
export const ToolCallPart = z.object({
|
||||
|
|
@ -15,6 +19,7 @@ export const ToolCallPart = z.object({
|
|||
toolCallId: z.string(),
|
||||
toolName: z.string(),
|
||||
arguments: z.any(),
|
||||
providerOptions: ProviderOptions.optional(),
|
||||
});
|
||||
|
||||
export const AssistantContentPart = z.union([
|
||||
|
|
@ -26,6 +31,7 @@ export const AssistantContentPart = z.union([
|
|||
export const UserMessage = z.object({
|
||||
role: z.literal("user"),
|
||||
content: z.string(),
|
||||
providerOptions: ProviderOptions.optional(),
|
||||
});
|
||||
|
||||
export const AssistantMessage = z.object({
|
||||
|
|
@ -34,11 +40,13 @@ export const AssistantMessage = z.object({
|
|||
z.string(),
|
||||
z.array(AssistantContentPart),
|
||||
]),
|
||||
providerOptions: ProviderOptions.optional(),
|
||||
});
|
||||
|
||||
export const SystemMessage = z.object({
|
||||
role: z.literal("system"),
|
||||
content: z.string(),
|
||||
providerOptions: ProviderOptions.optional(),
|
||||
});
|
||||
|
||||
export const ToolMessage = z.object({
|
||||
|
|
@ -46,6 +54,7 @@ export const ToolMessage = z.object({
|
|||
content: z.string(),
|
||||
toolCallId: z.string(),
|
||||
toolName: z.string(),
|
||||
providerOptions: ProviderOptions.optional(),
|
||||
});
|
||||
|
||||
export const Message = z.discriminatedUnion("role", [
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import z from "zod";
|
|||
|
||||
export const Flavor = z.enum([
|
||||
"rowboat [free]",
|
||||
"aigateway",
|
||||
"anthropic",
|
||||
"google",
|
||||
"ollama",
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
import { jsonSchema, ModelMessage } from "ai";
|
||||
import { jsonSchema, ModelMessage, modelMessageSchema } from "ai";
|
||||
import fs from "fs";
|
||||
import path from "path";
|
||||
import { getModelConfig, WorkDir } from "../config/config.js";
|
||||
import { Agent, ToolAttachment } from "../entities/agent.js";
|
||||
import { AssistantContentPart, AssistantMessage, Message, MessageList, ToolCallPart, ToolMessage, UserMessage } from "../entities/message.js";
|
||||
import { AssistantContentPart, AssistantMessage, Message, MessageList, ProviderOptions, ToolCallPart, ToolMessage, UserMessage } from "../entities/message.js";
|
||||
import { runIdGenerator } from "./run-id-gen.js";
|
||||
import { LanguageModel, stepCountIs, streamText, tool, Tool, ToolSet } from "ai";
|
||||
import { z } from "zod";
|
||||
|
|
@ -90,6 +90,7 @@ export class StreamStepMessageBuilder {
|
|||
private parts: z.infer<typeof AssistantContentPart>[] = [];
|
||||
private textBuffer: string = "";
|
||||
private reasoningBuffer: string = "";
|
||||
private providerOptions: z.infer<typeof ProviderOptions> | undefined = undefined;
|
||||
|
||||
flushBuffers() {
|
||||
// skip reasoning
|
||||
|
|
@ -123,8 +124,12 @@ export class StreamStepMessageBuilder {
|
|||
toolCallId: event.toolCallId,
|
||||
toolName: event.toolName,
|
||||
arguments: event.input,
|
||||
providerOptions: event.providerOptions,
|
||||
});
|
||||
break;
|
||||
case "finish-step":
|
||||
this.providerOptions = event.providerOptions;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -133,6 +138,7 @@ export class StreamStepMessageBuilder {
|
|||
return {
|
||||
role: "assistant",
|
||||
content: this.parts,
|
||||
providerOptions: this.providerOptions,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
@ -173,12 +179,14 @@ export async function loadAgent(id: string): Promise<z.infer<typeof Agent>> {
|
|||
export function convertFromMessages(messages: z.infer<typeof Message>[]): ModelMessage[] {
|
||||
const result: ModelMessage[] = [];
|
||||
for (const msg of messages) {
|
||||
const { providerOptions } = msg;
|
||||
switch (msg.role) {
|
||||
case "assistant":
|
||||
if (typeof msg.content === 'string') {
|
||||
result.push({
|
||||
role: "assistant",
|
||||
content: msg.content,
|
||||
providerOptions,
|
||||
});
|
||||
} else {
|
||||
result.push({
|
||||
|
|
@ -195,9 +203,11 @@ export function convertFromMessages(messages: z.infer<typeof Message>[]): ModelM
|
|||
toolCallId: part.toolCallId,
|
||||
toolName: part.toolName,
|
||||
input: part.arguments,
|
||||
providerOptions: part.providerOptions,
|
||||
};
|
||||
}
|
||||
}),
|
||||
providerOptions,
|
||||
});
|
||||
}
|
||||
break;
|
||||
|
|
@ -205,12 +215,14 @@ export function convertFromMessages(messages: z.infer<typeof Message>[]): ModelM
|
|||
result.push({
|
||||
role: "system",
|
||||
content: msg.content,
|
||||
providerOptions,
|
||||
});
|
||||
break;
|
||||
case "user":
|
||||
result.push({
|
||||
role: "user",
|
||||
content: msg.content,
|
||||
providerOptions,
|
||||
});
|
||||
break;
|
||||
case "tool":
|
||||
|
|
@ -227,11 +239,13 @@ export function convertFromMessages(messages: z.infer<typeof Message>[]): ModelM
|
|||
},
|
||||
},
|
||||
],
|
||||
providerOptions,
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
// doing this because: https://github.com/OpenRouterTeam/ai-sdk-provider/issues/262
|
||||
return JSON.parse(JSON.stringify(result));
|
||||
}
|
||||
|
||||
async function buildTools(agent: z.infer<typeof Agent>): Promise<ToolSet> {
|
||||
|
|
@ -446,7 +460,7 @@ export async function* streamAgent(state: AgentState): AsyncGenerator<z.infer<ty
|
|||
}
|
||||
|
||||
// if tool has been denied, deny
|
||||
if (state.deniedToolCallIds[toolCallId]) {
|
||||
if (state.deniedToolCallIds[toolCallId]) {
|
||||
yield* state.ingestAndLogAndYield({
|
||||
type: "message",
|
||||
message: {
|
||||
|
|
@ -561,7 +575,7 @@ export async function* streamAgent(state: AgentState): AsyncGenerator<z.infer<ty
|
|||
if (underlyingTool.type === "builtin" && underlyingTool.name === "executeCommand") {
|
||||
// if command is blocked, then seek permission
|
||||
if (isBlocked(part.arguments.command)) {
|
||||
yield *state.ingestAndLogAndYield({
|
||||
yield* state.ingestAndLogAndYield({
|
||||
type: "tool-permission-request",
|
||||
toolCall: part,
|
||||
subflow: [],
|
||||
|
|
@ -609,28 +623,33 @@ async function* streamLlm(
|
|||
case "reasoning-start":
|
||||
yield {
|
||||
type: "reasoning-start",
|
||||
providerOptions: event.providerMetadata,
|
||||
};
|
||||
break;
|
||||
case "reasoning-delta":
|
||||
yield {
|
||||
type: "reasoning-delta",
|
||||
delta: event.text,
|
||||
providerOptions: event.providerMetadata,
|
||||
};
|
||||
break;
|
||||
case "reasoning-end":
|
||||
yield {
|
||||
type: "reasoning-end",
|
||||
providerOptions: event.providerMetadata,
|
||||
};
|
||||
break;
|
||||
case "text-start":
|
||||
yield {
|
||||
type: "text-start",
|
||||
providerOptions: event.providerMetadata,
|
||||
};
|
||||
break;
|
||||
case "text-delta":
|
||||
yield {
|
||||
type: "text-delta",
|
||||
delta: event.text,
|
||||
providerOptions: event.providerMetadata,
|
||||
};
|
||||
break;
|
||||
case "tool-call":
|
||||
|
|
@ -639,12 +658,15 @@ async function* streamLlm(
|
|||
toolCallId: event.toolCallId,
|
||||
toolName: event.toolName,
|
||||
input: event.input,
|
||||
providerOptions: event.providerMetadata,
|
||||
};
|
||||
break;
|
||||
case "finish":
|
||||
case "finish-step":
|
||||
yield {
|
||||
type: "usage",
|
||||
usage: event.totalUsage,
|
||||
type: "finish-step",
|
||||
usage: event.usage,
|
||||
finishReason: event.finishReason,
|
||||
providerOptions: event.providerMetadata,
|
||||
};
|
||||
break;
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import { ProviderV2 } from "@ai-sdk/provider";
|
||||
import { createGateway } from "ai";
|
||||
import { createOpenAI } from "@ai-sdk/openai";
|
||||
import { createGoogleGenerativeAI } from "@ai-sdk/google";
|
||||
import { createAnthropic } from "@ai-sdk/anthropic";
|
||||
|
|
@ -28,9 +29,9 @@ export async function getProvider(name: string = ""): Promise<ProviderV2> {
|
|||
const { apiKey, baseURL, headers } = providerConfig;
|
||||
switch (providerConfig.flavor) {
|
||||
case "rowboat [free]":
|
||||
providerMap[name] = createOpenAICompatible({
|
||||
name: "rowboat [free]",
|
||||
baseURL: "https://ai-gateway.rowboatlabs.com/v1",
|
||||
providerMap[name] = createGateway({
|
||||
apiKey: "rowboatx",
|
||||
baseURL: "https://ai-gateway.rowboatlabs.com/v1/ai",
|
||||
});
|
||||
break;
|
||||
case "openai":
|
||||
|
|
@ -40,6 +41,13 @@ export async function getProvider(name: string = ""): Promise<ProviderV2> {
|
|||
headers,
|
||||
});
|
||||
break;
|
||||
case "aigateway":
|
||||
providerMap[name] = createGateway({
|
||||
apiKey,
|
||||
baseURL,
|
||||
headers
|
||||
});
|
||||
break;
|
||||
case "anthropic":
|
||||
providerMap[name] = createAnthropic({
|
||||
apiKey,
|
||||
|
|
@ -65,7 +73,7 @@ export async function getProvider(name: string = ""): Promise<ProviderV2> {
|
|||
name,
|
||||
apiKey,
|
||||
baseURL : baseURL || "",
|
||||
headers
|
||||
headers,
|
||||
});
|
||||
break;
|
||||
case "openrouter":
|
||||
|
|
|
|||
|
|
@ -77,8 +77,8 @@ export class StreamRenderer {
|
|||
case "tool-call":
|
||||
this.onToolCall(event.toolCallId, event.toolName, event.input);
|
||||
break;
|
||||
case "usage":
|
||||
this.onUsage(event.usage);
|
||||
case "finish-step":
|
||||
this.onFinishStep(event.finishReason, event.usage);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -219,13 +219,15 @@ export class StreamRenderer {
|
|||
this.write("\n");
|
||||
}
|
||||
|
||||
private onUsage(usage: {
|
||||
inputTokens?: number;
|
||||
outputTokens?: number;
|
||||
totalTokens?: number;
|
||||
reasoningTokens?: number;
|
||||
cachedInputTokens?: number;
|
||||
}) {
|
||||
private onFinishStep(
|
||||
finishReason: "stop" | "tool-calls" | "length" | "content-filter" | "error" | "other" | "unknown",
|
||||
usage: {
|
||||
inputTokens?: number;
|
||||
outputTokens?: number;
|
||||
totalTokens?: number;
|
||||
reasoningTokens?: number;
|
||||
cachedInputTokens?: number;
|
||||
}) {
|
||||
const parts: string[] = [];
|
||||
if (usage.inputTokens !== undefined) parts.push(`${this.dim("in:")} ${usage.inputTokens}`);
|
||||
if (usage.outputTokens !== undefined) parts.push(`${this.dim("out:")} ${usage.outputTokens}`);
|
||||
|
|
@ -234,8 +236,13 @@ export class StreamRenderer {
|
|||
if (usage.totalTokens !== undefined) parts.push(`${this.dim("total:")} ${this.bold(usage.totalTokens.toString())}`);
|
||||
const line = parts.join(this.dim(" | "));
|
||||
this.write("\n");
|
||||
this.write(this.dim("╭─ Usage\n"));
|
||||
this.write(this.dim("│ ") + line);
|
||||
this.write(this.bold("╭─ ") + this.bold("Finish"));
|
||||
this.write("\n");
|
||||
this.write(this.dim("│ ") + this.dim("reason: ") + finishReason);
|
||||
if (line.length) {
|
||||
this.write("\n");
|
||||
this.write(this.dim("│ ") + line);
|
||||
}
|
||||
this.write("\n");
|
||||
this.write(this.dim("╰─────────────\n"));
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue