diff --git a/apps/python-sdk/pyproject.toml b/apps/python-sdk/pyproject.toml index 55c5d849..202c1896 100644 --- a/apps/python-sdk/pyproject.toml +++ b/apps/python-sdk/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "rowboat" -version = "1.0.1" +version = "1.0.2" authors = [ { name = "Your Name", email = "your.email@example.com" }, ] diff --git a/apps/python-sdk/src/rowboat/client.py b/apps/python-sdk/src/rowboat/client.py index 39085cf7..9a203df2 100644 --- a/apps/python-sdk/src/rowboat/client.py +++ b/apps/python-sdk/src/rowboat/client.py @@ -25,10 +25,14 @@ class Client: self, messages: List[ApiMessage], state: Optional[Dict[str, Any]] = None, + skip_tool_calls: bool = False, + max_turns: int = 3 ) -> ApiResponse: request = ApiRequest( messages=messages, - state=state + state=state, + skipToolCalls=skip_tool_calls, + maxTurns=max_turns ) response = requests.post(self.base_url, headers=self.headers, data=request.model_dump_json()) @@ -75,7 +79,8 @@ class Client: messages: List[ApiMessage], tools: Optional[Dict[str, Callable[..., str]]] = None, state: Optional[Dict[str, Any]] = None, - max_turns: int = 3 + max_turns: int = 3, + skip_tool_calls: bool = False ) -> Tuple[List[ApiMessage], Optional[Dict[str, Any]]]: """Stateless chat method that handles a single conversation turn with multiple tool call rounds""" @@ -91,7 +96,9 @@ class Client: # call api response_data = self._call_api( messages=current_messages, - state=current_state + state=current_state, + skip_tool_calls=skip_tool_calls, + max_turns=max_turns ) current_messages.extend(response_data.messages) @@ -128,11 +135,15 @@ class StatefulChat: client: Client, tools: Optional[Dict[str, Callable[..., str]]] = None, system_prompt: Optional[str] = None, + max_turns: int = 3, + skip_tool_calls: bool = False ) -> None: self.client = client self.tools = tools self.messages: List[ApiMessage] = [] self.state: Optional[Dict[str, Any]] = None + self.max_turns = max_turns + self.skip_tool_calls = skip_tool_calls if system_prompt: self.messages.append(SystemMessage(role='system', content=system_prompt)) @@ -148,7 +159,9 @@ class StatefulChat: new_messages, new_state = self.client.chat( messages=self.messages, tools=self.tools, - state=self.state + state=self.state, + max_turns=self.max_turns, + skip_tool_calls=self.skip_tool_calls ) # Update internal state diff --git a/apps/python-sdk/src/rowboat/schema.py b/apps/python-sdk/src/rowboat/schema.py index 36edc393..1afa725a 100644 --- a/apps/python-sdk/src/rowboat/schema.py +++ b/apps/python-sdk/src/rowboat/schema.py @@ -48,6 +48,8 @@ ApiMessage = Union[ class ApiRequest(BaseModel): messages: List[ApiMessage] state: Any + skipToolCalls: Optional[bool] = None + maxTurns: Optional[int] = None class ApiResponse(BaseModel): messages: List[ApiMessage] diff --git a/apps/rowboat/app/actions/actions.ts b/apps/rowboat/app/actions/actions.ts index 006e3000..f2835f52 100644 --- a/apps/rowboat/app/actions/actions.ts +++ b/apps/rowboat/app/actions/actions.ts @@ -29,7 +29,7 @@ import { embeddingModel } from "../lib/embedding"; import { apiV1 } from "rowboat-shared"; import { zodToJsonSchema } from 'zod-to-json-schema'; import { Claims, getSession } from "@auth0/nextjs-auth0"; -import { callClientToolWebhook, getAgenticApiResponse } from "../lib/utils"; +import { callClientToolWebhook, getAgenticApiResponse, runRAGToolCall } from "../lib/utils"; import { assert } from "node:console"; import { check_query_limit } from "../lib/rate_limiting"; import { QueryLimitError } from "../lib/client_utils"; @@ -313,77 +313,7 @@ export async function getInformationTool( ): Promise> { await projectAuthCheck(projectId); - // create embedding for question - const embedResult = await embed({ - model: embeddingModel, - value: query, - }); - - // fetch all data sources for this project - const sources = await dataSourcesCollection.find({ - projectId: projectId, - active: true, - }).toArray(); - const validSourceIds = sources - .filter(s => sourceIds.includes(s._id.toString())) // id should be in sourceIds - .filter(s => s.active) // should be active - .map(s => s._id.toString()); - - // if no sources found, return empty response - if (validSourceIds.length === 0) { - return { - results: [], - }; - } - - // perform qdrant vector search - const qdrantResults = await qdrantClient.query("embeddings", { - query: embedResult.embedding, - filter: { - must: [ - { key: "projectId", match: { value: projectId } }, - { key: "sourceId", match: { any: validSourceIds } }, - ], - }, - limit: k, - with_payload: true, - }); - - // if return type is chunks, return the chunks - let results = qdrantResults.points.map((point) => { - const { title, name, content, docId, sourceId } = point.payload as z.infer['payload']; - return { - title, - name, - content, - docId, - sourceId, - }; - }); - - if (returnType === 'chunks') { - return { - results, - }; - } - - // otherwise, fetch the doc contents from mongodb - const docs = await dataSourceDocsCollection.find({ - _id: { $in: results.map(r => new ObjectId(r.docId)) }, - }).toArray(); - - // map the results to the docs - results = results.map(r => { - const doc = docs.find(d => d._id.toString() === r.docId); - return { - ...r, - content: doc?.content || '', - }; - }); - - return { - results, - }; + return await runRAGToolCall(projectId, query, sourceIds, returnType, k); } export async function simulateUserResponse( diff --git a/apps/rowboat/app/api/v1/[projectId]/chat/route.ts b/apps/rowboat/app/api/v1/[projectId]/chat/route.ts index 0489a19c..8b80ff88 100644 --- a/apps/rowboat/app/api/v1/[projectId]/chat/route.ts +++ b/apps/rowboat/app/api/v1/[projectId]/chat/route.ts @@ -4,12 +4,11 @@ import { z } from "zod"; import { ObjectId } from "mongodb"; import { authCheck } from "../../utils"; import { ApiRequest, ApiResponse } from "../../../../lib/types/types"; -import { convertFromAgenticApiToApiMessages } from "../../../../lib/types/agents_api_types"; -import { convertFromApiToAgenticApiMessages } from "../../../../lib/types/agents_api_types"; -import { convertWorkflowToAgenticAPI } from "../../../../lib/types/agents_api_types"; -import { AgenticAPIChatRequest } from "../../../../lib/types/agents_api_types"; -import { getAgenticApiResponse } from "../../../../lib/utils"; +import { AgenticAPIChatRequest, AgenticAPIChatMessage, convertFromAgenticApiToApiMessages, convertFromApiToAgenticApiMessages, convertWorkflowToAgenticAPI } from "../../../../lib/types/agents_api_types"; +import { getAgenticApiResponse, callClientToolWebhook, runRAGToolCall } from "../../../../lib/utils"; import { check_query_limit } from "../../../../lib/rate_limiting"; +import { apiV1 } from "rowboat-shared"; +import { PrefixLogger } from "../../../../lib/utils"; // get next turn / agent response export async function POST( @@ -17,9 +16,14 @@ export async function POST( { params }: { params: Promise<{ projectId: string }> } ): Promise { const { projectId } = await params; + const requestId = crypto.randomUUID(); + const logger = new PrefixLogger(`[${requestId}]`); + + logger.log(`Processing chat request for project ${projectId}`); // check query limit if (!await check_query_limit(projectId)) { + logger.log(`Query limit exceeded for project ${projectId}`); return Response.json({ error: "Query limit exceeded" }, { status: 429 }); } @@ -29,10 +33,12 @@ export async function POST( try { body = await req.json(); } catch (e) { + logger.log(`Invalid JSON in request body: ${e}`); return Response.json({ error: "Invalid JSON in request body" }, { status: 400 }); } const result = ApiRequest.safeParse(body); if (!result.success) { + logger.log(`Invalid request body: ${result.error.message}`); return Response.json({ error: `Invalid request body: ${result.error.message}` }, { status: 400 }); } const reqMessages = result.data.messages; @@ -43,9 +49,11 @@ export async function POST( _id: projectId, }); if (!project) { + logger.log(`Project ${projectId} not found`); return Response.json({ error: "Project not found" }, { status: 404 }); } if (!project.publishedWorkflowId) { + logger.log(`Project ${projectId} has no published workflow`); return Response.json({ error: "Project has no published workflow" }, { status: 404 }); } // fetch workflow @@ -54,27 +62,118 @@ export async function POST( _id: new ObjectId(project.publishedWorkflowId), }); if (!workflow) { + logger.log(`Workflow ${project.publishedWorkflowId} not found for project ${projectId}`); return Response.json({ error: "Workflow not found" }, { status: 404 }); } - // get assistant response - const { agents, tools, prompts, startAgent } = convertWorkflowToAgenticAPI(workflow); - const request: z.infer = { - messages: convertFromApiToAgenticApiMessages(reqMessages), - state: reqState ?? { last_agent_name: startAgent }, - agents, - tools, - prompts, - startAgent, - }; - console.log("turn: sending agentic request from /chat api", JSON.stringify(request, null, 2)); - const { messages, state } = await getAgenticApiResponse(request); + const MAX_TURNS = result.data.maxTurns ?? 3; + let currentMessages = reqMessages; + let currentState: unknown = reqState ?? { last_agent_name: workflow.agents[0].name }; + let turns = 0; + let hasToolCalls = false; - const response: z.infer = { - messages: convertFromAgenticApiToApiMessages(messages), - state, - }; + do { + hasToolCalls = false; + // get assistant response + const { agents, tools, prompts, startAgent } = convertWorkflowToAgenticAPI(workflow); + const request: z.infer = { + messages: convertFromApiToAgenticApiMessages(currentMessages), + state: currentState, + agents, + tools, + prompts, + startAgent, + }; - return Response.json(response); + console.log(`turn ${turns}: sending agentic request from /chat api`, JSON.stringify(request, null, 2)); + logger.log(`Processing turn ${turns} for conversation`); + const { messages: agenticMessages, state } = await getAgenticApiResponse(request); + + const newMessages = convertFromAgenticApiToApiMessages(agenticMessages); + currentState = state; + + // if tool calls are to be skipped, return immediately + if (result.data.skipToolCalls) { + logger.log('Skipping tool calls as requested'); + const responseBody: z.infer = { + messages: newMessages, + state: currentState, + }; + return Response.json(responseBody); + } + + // get last message to check for tool calls + const lastMessage = newMessages[newMessages.length - 1]; + if (lastMessage?.role === "assistant" && + 'tool_calls' in lastMessage && + lastMessage.tool_calls?.length > 0) { + hasToolCalls = true; + const toolCallResultMessages: z.infer[] = []; + + // Process tool calls + for (const toolCall of lastMessage.tool_calls) { + let result: unknown; + if (toolCall.function.name === "getArticleInfo") { + logger.log(`Running RAG tool call for agent ${lastMessage.agenticSender}`); + // find the source ids attached to this agent in the workflow + const agent = workflow.agents.find(a => a.name === lastMessage.agenticSender); + if (!agent) { + return Response.json({ error: "Agent not found" }, { status: 404 }); + } + const sourceIds = agent.ragDataSources; + if (!sourceIds) { + return Response.json({ error: "Agent has no data sources" }, { status: 404 }); + } + try { + result = await runRAGToolCall(projectId, toolCall.function.arguments, sourceIds, agent.ragReturnType, agent.ragK); + logger.log(`RAG tool call completed for agent ${lastMessage.agenticSender}`); + } catch (e) { + logger.log(`Error running RAG tool call: ${e}`); + return Response.json({ error: "Error running RAG tool call" }, { status: 500 }); + } + } else { + logger.log(`Running client tool webhook for tool ${toolCall.function.name}`); + // run other tool calls by calling the client tool webhook + try { + result = await callClientToolWebhook( + toolCall, + currentMessages, + projectId, + ); + logger.log(`Client tool webhook call completed for tool ${toolCall.function.name}`); + } catch (e) { + logger.log(`Error calling client tool webhook: ${e}`); + return Response.json({ error: "Error calling client tool webhook" }, { status: 500 }); + } + } + + toolCallResultMessages.push({ + role: "tool", + tool_call_id: toolCall.id, + content: JSON.stringify(result), + tool_name: toolCall.function.name, + }); + } + + // Add new messages to the conversation + currentMessages = [...currentMessages, ...newMessages, ...toolCallResultMessages]; + } else { + // No tool calls, just add the new messages + currentMessages = [...currentMessages, ...newMessages]; + } + + turns++; + if (turns >= MAX_TURNS && hasToolCalls) { + logger.log(`Max turns (${MAX_TURNS}) reached for conversation`); + return Response.json({ error: "Max turns reached" }, { status: 429 }); + } + + } while (hasToolCalls); + + const responseBody: z.infer = { + messages: currentMessages, + state: currentState, + }; + return Response.json(responseBody); }); } diff --git a/apps/rowboat/app/lib/types/types.ts b/apps/rowboat/app/lib/types/types.ts index 192f6980..8122cb31 100644 --- a/apps/rowboat/app/lib/types/types.ts +++ b/apps/rowboat/app/lib/types/types.ts @@ -107,6 +107,8 @@ export const ApiMessage = z.union([ export const ApiRequest = z.object({ messages: z.array(ApiMessage), state: z.unknown(), + skipToolCalls: z.boolean().optional(), + maxTurns: z.number().optional(), }); export const ApiResponse = z.object({ diff --git a/apps/rowboat/app/lib/utils.ts b/apps/rowboat/app/lib/utils.ts index 1733f6b7..1ad35304 100644 --- a/apps/rowboat/app/lib/utils.ts +++ b/apps/rowboat/app/lib/utils.ts @@ -1,20 +1,26 @@ import { convertFromAgenticAPIChatMessages } from "./types/agents_api_types"; import { ClientToolCallRequest } from "./types/tool_types"; -import { ClientToolCallJwt } from "./types/tool_types"; +import { ClientToolCallJwt, GetInformationToolResult } from "./types/tool_types"; import { ClientToolCallRequestBody } from "./types/tool_types"; import { AgenticAPIChatResponse } from "./types/agents_api_types"; import { AgenticAPIChatRequest } from "./types/agents_api_types"; -import { Workflow } from "./types/workflow_types"; +import { Workflow, WorkflowAgent } from "./types/workflow_types"; import { AgenticAPIChatMessage } from "./types/agents_api_types"; import { z } from "zod"; -import { projectsCollection } from "./mongodb"; +import { dataSourceDocsCollection, dataSourcesCollection, projectsCollection } from "./mongodb"; import { apiV1 } from "rowboat-shared"; import { SignJWT } from "jose"; import crypto from "crypto"; +import { ObjectId } from "mongodb"; +import { embeddingModel } from "./embedding"; +import { embed } from "ai"; +import { qdrantClient } from "./qdrant"; +import { EmbeddingRecord } from "./types/datasource_types"; +import { ApiMessage } from "./types/types"; export async function callClientToolWebhook( toolCall: z.infer['tool_calls'][number], - messages: z.infer[], + messages: z.infer[], projectId: string, ): Promise { const project = await projectsCollection.findOne({ @@ -105,4 +111,110 @@ export async function getAgenticApiResponse( state: result.state, rawAPIResponse: result, }; -} \ No newline at end of file +} + +export async function runRAGToolCall( + projectId: string, + query: string, + sourceIds: string[], + returnType: z.infer['ragReturnType'], + k: number, +): Promise> { + // create embedding for question + const embedResult = await embed({ + model: embeddingModel, + value: query, + }); + + // fetch all data sources for this project + const sources = await dataSourcesCollection.find({ + projectId: projectId, + active: true, + }).toArray(); + const validSourceIds = sources + .filter(s => sourceIds.includes(s._id.toString())) // id should be in sourceIds + .filter(s => s.active) // should be active + .map(s => s._id.toString()); + + // if no sources found, return empty response + if (validSourceIds.length === 0) { + return { + results: [], + }; + } + + // perform qdrant vector search + const qdrantResults = await qdrantClient.query("embeddings", { + query: embedResult.embedding, + filter: { + must: [ + { key: "projectId", match: { value: projectId } }, + { key: "sourceId", match: { any: validSourceIds } }, + ], + }, + limit: k, + with_payload: true, + }); + + // if return type is chunks, return the chunks + let results = qdrantResults.points.map((point) => { + const { title, name, content, docId, sourceId } = point.payload as z.infer['payload']; + return { + title, + name, + content, + docId, + sourceId, + }; + }); + + if (returnType === 'chunks') { + return { + results, + }; + } + + // otherwise, fetch the doc contents from mongodb + const docs = await dataSourceDocsCollection.find({ + _id: { $in: results.map(r => new ObjectId(r.docId)) }, + }).toArray(); + + // map the results to the docs + results = results.map(r => { + const doc = docs.find(d => d._id.toString() === r.docId); + return { + ...r, + content: doc?.content || '', + }; + }); + + return { + results, + }; +} +// create a PrefixLogger class that wraps console.log with a prefix +// and allows chaining with a parent logger +export class PrefixLogger { + private prefix: string; + private parent: PrefixLogger | null; + + constructor(prefix: string, parent: PrefixLogger | null = null) { + this.prefix = prefix; + this.parent = parent; + } + + log(...args: any[]) { + const timestamp = new Date().toISOString(); + const prefix = '[' + this.prefix + ']'; + + if (this.parent) { + this.parent.log(prefix, ...args); + } else { + console.log(timestamp, prefix, ...args); + } + } + + child(childPrefix: string): PrefixLogger { + return new PrefixLogger(childPrefix, this); + } +} diff --git a/apps/rowboat/app/projects/[projectId]/workflow/agent_config.tsx b/apps/rowboat/app/projects/[projectId]/workflow/agent_config.tsx index 6e9b0a75..47f441df 100644 --- a/apps/rowboat/app/projects/[projectId]/workflow/agent_config.tsx +++ b/apps/rowboat/app/projects/[projectId]/workflow/agent_config.tsx @@ -167,9 +167,6 @@ export function AgentConfig({