autorun tools in api

This commit is contained in:
ramnique 2025-02-14 23:20:18 +05:30
parent 2c413bf165
commit 64312e2d5c
11 changed files with 264 additions and 135 deletions

View file

@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "rowboat"
version = "1.0.1"
version = "1.0.2"
authors = [
{ name = "Your Name", email = "your.email@example.com" },
]

View file

@ -25,10 +25,14 @@ class Client:
self,
messages: List[ApiMessage],
state: Optional[Dict[str, Any]] = None,
skip_tool_calls: bool = False,
max_turns: int = 3
) -> ApiResponse:
request = ApiRequest(
messages=messages,
state=state
state=state,
skipToolCalls=skip_tool_calls,
maxTurns=max_turns
)
response = requests.post(self.base_url, headers=self.headers, data=request.model_dump_json())
@ -75,7 +79,8 @@ class Client:
messages: List[ApiMessage],
tools: Optional[Dict[str, Callable[..., str]]] = None,
state: Optional[Dict[str, Any]] = None,
max_turns: int = 3
max_turns: int = 3,
skip_tool_calls: bool = False
) -> Tuple[List[ApiMessage], Optional[Dict[str, Any]]]:
"""Stateless chat method that handles a single conversation turn with multiple tool call rounds"""
@ -91,7 +96,9 @@ class Client:
# call api
response_data = self._call_api(
messages=current_messages,
state=current_state
state=current_state,
skip_tool_calls=skip_tool_calls,
max_turns=max_turns
)
current_messages.extend(response_data.messages)
@ -128,11 +135,15 @@ class StatefulChat:
client: Client,
tools: Optional[Dict[str, Callable[..., str]]] = None,
system_prompt: Optional[str] = None,
max_turns: int = 3,
skip_tool_calls: bool = False
) -> None:
self.client = client
self.tools = tools
self.messages: List[ApiMessage] = []
self.state: Optional[Dict[str, Any]] = None
self.max_turns = max_turns
self.skip_tool_calls = skip_tool_calls
if system_prompt:
self.messages.append(SystemMessage(role='system', content=system_prompt))
@ -148,7 +159,9 @@ class StatefulChat:
new_messages, new_state = self.client.chat(
messages=self.messages,
tools=self.tools,
state=self.state
state=self.state,
max_turns=self.max_turns,
skip_tool_calls=self.skip_tool_calls
)
# Update internal state

View file

@ -48,6 +48,8 @@ ApiMessage = Union[
class ApiRequest(BaseModel):
messages: List[ApiMessage]
state: Any
skipToolCalls: Optional[bool] = None
maxTurns: Optional[int] = None
class ApiResponse(BaseModel):
messages: List[ApiMessage]

View file

@ -29,7 +29,7 @@ import { embeddingModel } from "../lib/embedding";
import { apiV1 } from "rowboat-shared";
import { zodToJsonSchema } from 'zod-to-json-schema';
import { Claims, getSession } from "@auth0/nextjs-auth0";
import { callClientToolWebhook, getAgenticApiResponse } from "../lib/utils";
import { callClientToolWebhook, getAgenticApiResponse, runRAGToolCall } from "../lib/utils";
import { assert } from "node:console";
import { check_query_limit } from "../lib/rate_limiting";
import { QueryLimitError } from "../lib/client_utils";
@ -313,77 +313,7 @@ export async function getInformationTool(
): Promise<z.infer<typeof GetInformationToolResult>> {
await projectAuthCheck(projectId);
// create embedding for question
const embedResult = await embed({
model: embeddingModel,
value: query,
});
// fetch all data sources for this project
const sources = await dataSourcesCollection.find({
projectId: projectId,
active: true,
}).toArray();
const validSourceIds = sources
.filter(s => sourceIds.includes(s._id.toString())) // id should be in sourceIds
.filter(s => s.active) // should be active
.map(s => s._id.toString());
// if no sources found, return empty response
if (validSourceIds.length === 0) {
return {
results: [],
};
}
// perform qdrant vector search
const qdrantResults = await qdrantClient.query("embeddings", {
query: embedResult.embedding,
filter: {
must: [
{ key: "projectId", match: { value: projectId } },
{ key: "sourceId", match: { any: validSourceIds } },
],
},
limit: k,
with_payload: true,
});
// if return type is chunks, return the chunks
let results = qdrantResults.points.map((point) => {
const { title, name, content, docId, sourceId } = point.payload as z.infer<typeof EmbeddingRecord>['payload'];
return {
title,
name,
content,
docId,
sourceId,
};
});
if (returnType === 'chunks') {
return {
results,
};
}
// otherwise, fetch the doc contents from mongodb
const docs = await dataSourceDocsCollection.find({
_id: { $in: results.map(r => new ObjectId(r.docId)) },
}).toArray();
// map the results to the docs
results = results.map(r => {
const doc = docs.find(d => d._id.toString() === r.docId);
return {
...r,
content: doc?.content || '',
};
});
return {
results,
};
return await runRAGToolCall(projectId, query, sourceIds, returnType, k);
}
export async function simulateUserResponse(

View file

@ -4,12 +4,11 @@ import { z } from "zod";
import { ObjectId } from "mongodb";
import { authCheck } from "../../utils";
import { ApiRequest, ApiResponse } from "../../../../lib/types/types";
import { convertFromAgenticApiToApiMessages } from "../../../../lib/types/agents_api_types";
import { convertFromApiToAgenticApiMessages } from "../../../../lib/types/agents_api_types";
import { convertWorkflowToAgenticAPI } from "../../../../lib/types/agents_api_types";
import { AgenticAPIChatRequest } from "../../../../lib/types/agents_api_types";
import { getAgenticApiResponse } from "../../../../lib/utils";
import { AgenticAPIChatRequest, AgenticAPIChatMessage, convertFromAgenticApiToApiMessages, convertFromApiToAgenticApiMessages, convertWorkflowToAgenticAPI } from "../../../../lib/types/agents_api_types";
import { getAgenticApiResponse, callClientToolWebhook, runRAGToolCall } from "../../../../lib/utils";
import { check_query_limit } from "../../../../lib/rate_limiting";
import { apiV1 } from "rowboat-shared";
import { PrefixLogger } from "../../../../lib/utils";
// get next turn / agent response
export async function POST(
@ -17,9 +16,14 @@ export async function POST(
{ params }: { params: Promise<{ projectId: string }> }
): Promise<Response> {
const { projectId } = await params;
const requestId = crypto.randomUUID();
const logger = new PrefixLogger(`[${requestId}]`);
logger.log(`Processing chat request for project ${projectId}`);
// check query limit
if (!await check_query_limit(projectId)) {
logger.log(`Query limit exceeded for project ${projectId}`);
return Response.json({ error: "Query limit exceeded" }, { status: 429 });
}
@ -29,10 +33,12 @@ export async function POST(
try {
body = await req.json();
} catch (e) {
logger.log(`Invalid JSON in request body: ${e}`);
return Response.json({ error: "Invalid JSON in request body" }, { status: 400 });
}
const result = ApiRequest.safeParse(body);
if (!result.success) {
logger.log(`Invalid request body: ${result.error.message}`);
return Response.json({ error: `Invalid request body: ${result.error.message}` }, { status: 400 });
}
const reqMessages = result.data.messages;
@ -43,9 +49,11 @@ export async function POST(
_id: projectId,
});
if (!project) {
logger.log(`Project ${projectId} not found`);
return Response.json({ error: "Project not found" }, { status: 404 });
}
if (!project.publishedWorkflowId) {
logger.log(`Project ${projectId} has no published workflow`);
return Response.json({ error: "Project has no published workflow" }, { status: 404 });
}
// fetch workflow
@ -54,27 +62,118 @@ export async function POST(
_id: new ObjectId(project.publishedWorkflowId),
});
if (!workflow) {
logger.log(`Workflow ${project.publishedWorkflowId} not found for project ${projectId}`);
return Response.json({ error: "Workflow not found" }, { status: 404 });
}
// get assistant response
const { agents, tools, prompts, startAgent } = convertWorkflowToAgenticAPI(workflow);
const request: z.infer<typeof AgenticAPIChatRequest> = {
messages: convertFromApiToAgenticApiMessages(reqMessages),
state: reqState ?? { last_agent_name: startAgent },
agents,
tools,
prompts,
startAgent,
};
console.log("turn: sending agentic request from /chat api", JSON.stringify(request, null, 2));
const { messages, state } = await getAgenticApiResponse(request);
const MAX_TURNS = result.data.maxTurns ?? 3;
let currentMessages = reqMessages;
let currentState: unknown = reqState ?? { last_agent_name: workflow.agents[0].name };
let turns = 0;
let hasToolCalls = false;
const response: z.infer<typeof ApiResponse> = {
messages: convertFromAgenticApiToApiMessages(messages),
state,
};
do {
hasToolCalls = false;
// get assistant response
const { agents, tools, prompts, startAgent } = convertWorkflowToAgenticAPI(workflow);
const request: z.infer<typeof AgenticAPIChatRequest> = {
messages: convertFromApiToAgenticApiMessages(currentMessages),
state: currentState,
agents,
tools,
prompts,
startAgent,
};
return Response.json(response);
console.log(`turn ${turns}: sending agentic request from /chat api`, JSON.stringify(request, null, 2));
logger.log(`Processing turn ${turns} for conversation`);
const { messages: agenticMessages, state } = await getAgenticApiResponse(request);
const newMessages = convertFromAgenticApiToApiMessages(agenticMessages);
currentState = state;
// if tool calls are to be skipped, return immediately
if (result.data.skipToolCalls) {
logger.log('Skipping tool calls as requested');
const responseBody: z.infer<typeof ApiResponse> = {
messages: newMessages,
state: currentState,
};
return Response.json(responseBody);
}
// get last message to check for tool calls
const lastMessage = newMessages[newMessages.length - 1];
if (lastMessage?.role === "assistant" &&
'tool_calls' in lastMessage &&
lastMessage.tool_calls?.length > 0) {
hasToolCalls = true;
const toolCallResultMessages: z.infer<typeof apiV1.ToolMessage>[] = [];
// Process tool calls
for (const toolCall of lastMessage.tool_calls) {
let result: unknown;
if (toolCall.function.name === "getArticleInfo") {
logger.log(`Running RAG tool call for agent ${lastMessage.agenticSender}`);
// find the source ids attached to this agent in the workflow
const agent = workflow.agents.find(a => a.name === lastMessage.agenticSender);
if (!agent) {
return Response.json({ error: "Agent not found" }, { status: 404 });
}
const sourceIds = agent.ragDataSources;
if (!sourceIds) {
return Response.json({ error: "Agent has no data sources" }, { status: 404 });
}
try {
result = await runRAGToolCall(projectId, toolCall.function.arguments, sourceIds, agent.ragReturnType, agent.ragK);
logger.log(`RAG tool call completed for agent ${lastMessage.agenticSender}`);
} catch (e) {
logger.log(`Error running RAG tool call: ${e}`);
return Response.json({ error: "Error running RAG tool call" }, { status: 500 });
}
} else {
logger.log(`Running client tool webhook for tool ${toolCall.function.name}`);
// run other tool calls by calling the client tool webhook
try {
result = await callClientToolWebhook(
toolCall,
currentMessages,
projectId,
);
logger.log(`Client tool webhook call completed for tool ${toolCall.function.name}`);
} catch (e) {
logger.log(`Error calling client tool webhook: ${e}`);
return Response.json({ error: "Error calling client tool webhook" }, { status: 500 });
}
}
toolCallResultMessages.push({
role: "tool",
tool_call_id: toolCall.id,
content: JSON.stringify(result),
tool_name: toolCall.function.name,
});
}
// Add new messages to the conversation
currentMessages = [...currentMessages, ...newMessages, ...toolCallResultMessages];
} else {
// No tool calls, just add the new messages
currentMessages = [...currentMessages, ...newMessages];
}
turns++;
if (turns >= MAX_TURNS && hasToolCalls) {
logger.log(`Max turns (${MAX_TURNS}) reached for conversation`);
return Response.json({ error: "Max turns reached" }, { status: 429 });
}
} while (hasToolCalls);
const responseBody: z.infer<typeof ApiResponse> = {
messages: currentMessages,
state: currentState,
};
return Response.json(responseBody);
});
}

View file

@ -107,6 +107,8 @@ export const ApiMessage = z.union([
export const ApiRequest = z.object({
messages: z.array(ApiMessage),
state: z.unknown(),
skipToolCalls: z.boolean().optional(),
maxTurns: z.number().optional(),
});
export const ApiResponse = z.object({

View file

@ -1,20 +1,26 @@
import { convertFromAgenticAPIChatMessages } from "./types/agents_api_types";
import { ClientToolCallRequest } from "./types/tool_types";
import { ClientToolCallJwt } from "./types/tool_types";
import { ClientToolCallJwt, GetInformationToolResult } from "./types/tool_types";
import { ClientToolCallRequestBody } from "./types/tool_types";
import { AgenticAPIChatResponse } from "./types/agents_api_types";
import { AgenticAPIChatRequest } from "./types/agents_api_types";
import { Workflow } from "./types/workflow_types";
import { Workflow, WorkflowAgent } from "./types/workflow_types";
import { AgenticAPIChatMessage } from "./types/agents_api_types";
import { z } from "zod";
import { projectsCollection } from "./mongodb";
import { dataSourceDocsCollection, dataSourcesCollection, projectsCollection } from "./mongodb";
import { apiV1 } from "rowboat-shared";
import { SignJWT } from "jose";
import crypto from "crypto";
import { ObjectId } from "mongodb";
import { embeddingModel } from "./embedding";
import { embed } from "ai";
import { qdrantClient } from "./qdrant";
import { EmbeddingRecord } from "./types/datasource_types";
import { ApiMessage } from "./types/types";
export async function callClientToolWebhook(
toolCall: z.infer<typeof apiV1.AssistantMessageWithToolCalls>['tool_calls'][number],
messages: z.infer<typeof apiV1.ChatMessage>[],
messages: z.infer<typeof ApiMessage>[],
projectId: string,
): Promise<unknown> {
const project = await projectsCollection.findOne({
@ -105,4 +111,110 @@ export async function getAgenticApiResponse(
state: result.state,
rawAPIResponse: result,
};
}
}
/**
 * Runs a retrieval (RAG) lookup for `query` against a project's data sources.
 *
 * Steps:
 *  1. Keep only those `sourceIds` that exist as active data sources of the project.
 *  2. Embed the query and run a Qdrant vector search over the "embeddings"
 *     collection, restricted to the project and the valid sources.
 *  3. When `returnType` is 'chunks', return the matched chunks as-is; otherwise
 *     replace each chunk's `content` with the content of its parent document
 *     fetched from MongoDB ('' when the document is not found).
 *
 * @param projectId  Project whose data sources are searched.
 * @param query      Text that is embedded and used as the search vector.
 * @param sourceIds  Data source ids (stringified ObjectIds) the search may use.
 * @param returnType 'chunks' returns the raw chunks; any other value returns document content.
 * @param k          Maximum number of points requested from Qdrant.
 * @returns `{ results }`; `results` is empty when none of the requested sources is active.
 */
export async function runRAGToolCall(
    projectId: string,
    query: string,
    sourceIds: string[],
    returnType: z.infer<typeof WorkflowAgent>['ragReturnType'],
    k: number,
): Promise<z.infer<typeof GetInformationToolResult>> {
    // nothing was requested: skip the database lookup and the embedding call
    if (sourceIds.length === 0) {
        return { results: [] };
    }

    // fetch the active data sources of this project and keep the requested ones
    // (the query already restricts to active sources, so no second filter is needed)
    const requestedIds = new Set(sourceIds);
    const sources = await dataSourcesCollection.find({
        projectId: projectId,
        active: true,
    }).toArray();
    const validSourceIds = sources
        .map(s => s._id.toString())
        .filter(id => requestedIds.has(id));

    // if no sources found, return empty response
    if (validSourceIds.length === 0) {
        return { results: [] };
    }

    // create the embedding only now, so the model call is not wasted
    // when there is nothing to search
    const embedResult = await embed({
        model: embeddingModel,
        value: query,
    });

    // perform qdrant vector search
    const qdrantResults = await qdrantClient.query("embeddings", {
        query: embedResult.embedding,
        filter: {
            must: [
                { key: "projectId", match: { value: projectId } },
                { key: "sourceId", match: { any: validSourceIds } },
            ],
        },
        limit: k,
        with_payload: true,
    });

    let results = qdrantResults.points.map((point) => {
        const { title, name, content, docId, sourceId } = point.payload as z.infer<typeof EmbeddingRecord>['payload'];
        return {
            title,
            name,
            content,
            docId,
            sourceId,
        };
    });

    // chunks requested, or nothing matched: no document lookup needed
    if (returnType === 'chunks' || results.length === 0) {
        return { results };
    }

    // otherwise, fetch the doc contents from mongodb (each document only once)
    const uniqueDocIds = [...new Set(results.map(r => r.docId))];
    const docs = await dataSourceDocsCollection.find({
        _id: { $in: uniqueDocIds.map(id => new ObjectId(id)) },
    }).toArray();

    // index documents by id so each result is resolved with a single lookup
    const contentByDocId = new Map(
        docs.map(d => [d._id.toString(), d.content] as const),
    );

    // swap each chunk's content for its document's content
    results = results.map(r => ({
        ...r,
        content: contentByDocId.get(r.docId) || '',
    }));

    return { results };
}
/**
 * Logger that wraps `console.log` and tags every message with a prefix.
 *
 * The prefix is wrapped in square brackets by this class, so pass the bare
 * label (e.g. `"worker"`, not `"[worker]"`). Loggers can be chained: a child
 * forwards its bracketed prefix and arguments to its parent, and only the
 * root logger (the one without a parent) writes to the console, prepending
 * an ISO-8601 timestamp.
 */
export class PrefixLogger {
    private readonly prefix: string;
    private readonly parent: PrefixLogger | null;

    /**
     * @param prefix Label shown in square brackets before each message.
     * @param parent Logger that receives this logger's output; `null` for a root logger.
     */
    constructor(prefix: string, parent: PrefixLogger | null = null) {
        this.prefix = prefix;
        this.parent = parent;
    }

    /**
     * Logs `args` preceded by this logger's bracketed prefix (and, through the
     * parent chain, by every ancestor's prefix).
     */
    log(...args: unknown[]): void {
        const prefix = `[${this.prefix}]`;
        if (this.parent) {
            // let the parent add its own prefix; the root adds the timestamp
            this.parent.log(prefix, ...args);
            return;
        }
        console.log(new Date().toISOString(), prefix, ...args);
    }

    /**
     * Creates a logger whose output is routed through this one.
     *
     * @param childPrefix Label of the new logger.
     * @returns A logger with `this` as its parent.
     */
    child(childPrefix: string): PrefixLogger {
        return new PrefixLogger(childPrefix, this);
    }
}

View file

@ -167,9 +167,6 @@ export function AgentConfig({
<div className="flex flex-col gap-4 items-start">
<Label label="RAG (beta)" />
{agent.ragDataSources && agent.ragDataSources.length > 0 && <div className="text-xs text-red-500">
<sup>*</sup> RAG data sources are currently not supported in the API. (coming soon)
</div>}
<List
items={agent.ragDataSources?.map((source) => ({
id: source,

View file

@ -10,7 +10,7 @@ import { WithId } from 'mongodb';
import { embedMany } from 'ai';
import { embeddingModel } from '../lib/embedding';
import { qdrantClient } from '../lib/qdrant';
import { PrefixLogger } from './shared';
import { PrefixLogger } from "../lib/utils";
import { GoogleGenerativeAI } from "@google/generative-ai";
import { GetObjectCommand } from "@aws-sdk/client-s3";
import { uploadsS3Client } from '../lib/uploads_s3_client';

View file

@ -10,7 +10,7 @@ import { WithId } from 'mongodb';
import { embedMany } from 'ai';
import { embeddingModel } from '../lib/embedding';
import { qdrantClient } from '../lib/qdrant';
import { PrefixLogger } from './shared';
import { PrefixLogger } from "../lib/utils";
const firecrawl = new FirecrawlApp({ apiKey: process.env.FIRECRAWL_API_KEY });

View file

@ -1,26 +0,0 @@
// PrefixLogger: a thin console.log wrapper that puts a bracketed prefix in
// front of every message; loggers can be nested under a parent logger
export class PrefixLogger {
    constructor(
        private prefix: string,
        private parent: PrefixLogger | null = null,
    ) {}

    // Emit args behind this logger's "[prefix]". A root logger prints directly
    // (with an ISO timestamp first); a nested logger hands off to its parent.
    log(...args: any[]) {
        const tag = `[${this.prefix}]`;
        if (!this.parent) {
            const now = new Date().toISOString();
            console.log(now, tag, ...args);
            return;
        }
        this.parent.log(tag, ...args);
    }

    // Build a nested logger that reports through this one.
    child(childPrefix: string): PrefixLogger {
        const nested = new PrefixLogger(childPrefix, this);
        return nested;
    }
}