diff --git a/apps/rowboat/app/actions/actions.ts b/apps/rowboat/app/actions/actions.ts index ed8ced54..2004a9eb 100644 --- a/apps/rowboat/app/actions/actions.ts +++ b/apps/rowboat/app/actions/actions.ts @@ -1,5 +1,5 @@ 'use server'; -import { convertFromAgenticAPIChatMessages } from "../lib/types/agents_api_types"; +import { AgenticAPIInitStreamResponse, convertFromAgenticAPIChatMessages } from "../lib/types/agents_api_types"; import { AgenticAPIChatRequest } from "../lib/types/agents_api_types"; import { WebpageCrawlResponse } from "../lib/types/tool_types"; import { webpagesCollection } from "../lib/mongodb"; @@ -7,7 +7,7 @@ import { z } from 'zod'; import FirecrawlApp, { ScrapeResponse } from '@mendable/firecrawl-js'; import { apiV1 } from "rowboat-shared"; import { Claims, getSession } from "@auth0/nextjs-auth0"; -import { getAgenticApiResponse } from "../lib/utils"; +import { getAgenticApiResponse, getAgenticResponseStreamId } from "../lib/utils"; import { check_query_limit } from "../lib/rate_limiting"; import { QueryLimitError } from "../lib/client_utils"; import { projectAuthCheck } from "./project_actions"; @@ -85,3 +85,13 @@ export async function getAssistantResponse(request: z.infer): Promise> { + await projectAuthCheck(request.projectId); + if (!await check_query_limit(request.projectId)) { + throw new QueryLimitError(); + } + + const response = await getAgenticResponseStreamId(request); + return response; +} \ No newline at end of file diff --git a/apps/rowboat/app/actions/mcp_actions.ts b/apps/rowboat/app/actions/mcp_actions.ts index eb2352f4..9ca6668f 100644 --- a/apps/rowboat/app/actions/mcp_actions.ts +++ b/apps/rowboat/app/actions/mcp_actions.ts @@ -6,6 +6,7 @@ import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js"; import { projectAuthCheck } from "./project_actions"; import { projectsCollection } from "../lib/mongodb"; import { Project } from "../lib/types/project_types"; +import { MCPServer } from "../lib/types/types"; export async function fetchMcpTools(projectId: string): Promise[]> { await projectAuthCheck(projectId); @@ -71,4 +72,12 @@ export async function updateMcpServers(projectId: string, mcpServers: z.infer[]> { + await projectAuthCheck(projectId); + const project = await projectsCollection.findOne({ + _id: projectId, + }); + return project?.mcpServers ?? []; } \ No newline at end of file diff --git a/apps/rowboat/app/api/v1/stream-response/[streamId]/route.ts b/apps/rowboat/app/api/v1/stream-response/[streamId]/route.ts new file mode 100644 index 00000000..dac45025 --- /dev/null +++ b/apps/rowboat/app/api/v1/stream-response/[streamId]/route.ts @@ -0,0 +1,45 @@ +export async function GET(request: Request, { params }: { params: { streamId: string } }) { + // Replace with your actual upstream SSE endpoint. + const upstreamUrl = `${process.env.AGENTS_API_URL}/chat_stream/${params.streamId}`; + console.log('upstreamUrl', upstreamUrl); + + // Fetch the upstream SSE stream. + const upstreamResponse = await fetch(upstreamUrl, { + headers: { + 'Authorization': `Bearer ${process.env.AGENTS_API_KEY}`, + }, + cache: 'no-store', + }); + + // If the upstream request fails, return a 502 Bad Gateway. + if (!upstreamResponse.ok || !upstreamResponse.body) { + return new Response("Error connecting to upstream SSE stream", { status: 502 }); + } + + const reader = upstreamResponse.body.getReader(); + + const stream = new ReadableStream({ + async start(controller) { + try { + // Read from the upstream stream continuously. + while (true) { + const { done, value } = await reader.read(); + if (done) break; + // Immediately enqueue each received chunk. + controller.enqueue(value); + } + controller.close(); + } catch (error) { + controller.error(error); + } + }, + }); + + return new Response(stream, { + headers: { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + }); +} \ No newline at end of file diff --git a/apps/rowboat/app/lib/types/agents_api_types.ts b/apps/rowboat/app/lib/types/agents_api_types.ts index 42c652da..c4d111ec 100644 --- a/apps/rowboat/app/lib/types/agents_api_types.ts +++ b/apps/rowboat/app/lib/types/agents_api_types.ts @@ -64,6 +64,10 @@ export const AgenticAPIChatResponse = z.object({ state: z.unknown(), }); +export const AgenticAPIInitStreamResponse = z.object({ + streamId: z.string(), +}); + export function convertWorkflowToAgenticAPI(workflow: z.infer): { agents: z.infer[]; tools: z.infer[]; diff --git a/apps/rowboat/app/lib/utils.ts b/apps/rowboat/app/lib/utils.ts index 24568a89..2f44d4c0 100644 --- a/apps/rowboat/app/lib/utils.ts +++ b/apps/rowboat/app/lib/utils.ts @@ -1,4 +1,4 @@ -import { AgenticAPIChatResponse, AgenticAPIChatRequest, AgenticAPIChatMessage } from "./types/agents_api_types"; +import { AgenticAPIChatResponse, AgenticAPIChatRequest, AgenticAPIChatMessage, AgenticAPIInitStreamResponse } from "./types/agents_api_types"; import { z } from "zod"; import { generateObject } from "ai"; import { ApiMessage } from "./types/types"; @@ -35,6 +35,29 @@ export async function getAgenticApiResponse( }; } +export async function getAgenticResponseStreamId( + request: z.infer, +): Promise> { + // call agentic api + console.log(`sending agentic api init stream request`, JSON.stringify(request)); + const response = await fetch(process.env.AGENTS_API_URL + '/chat_stream_init', { + method: 'POST', + body: JSON.stringify(request), + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${process.env.AGENTS_API_KEY || 'test'}`, + }, + }); + if (!response.ok) { + console.error('Failed to call agentic init stream api', response); + throw new Error(`Failed to call agentic init stream api: ${response.statusText}`); + } + const responseJson = await response.json(); + console.log(`received agentic api init stream response`, JSON.stringify(responseJson)); + const result: z.infer = responseJson; + return result; +} + // create a PrefixLogger class that wraps console.log with a prefix // and allows chaining with a parent logger export class PrefixLogger { diff --git a/apps/rowboat/app/projects/[projectId]/playground/app.tsx b/apps/rowboat/app/projects/[projectId]/playground/app.tsx index 69a5202f..33c16663 100644 --- a/apps/rowboat/app/projects/[projectId]/playground/app.tsx +++ b/apps/rowboat/app/projects/[projectId]/playground/app.tsx @@ -48,10 +48,6 @@ export function App({ setCounter(counter + 1); } - if (hidden) { - return <>; - } - function handleNewChatButtonClick() { setCounter(counter + 1); setChat({ @@ -63,6 +59,10 @@ export function App({ }); } + if (hidden) { + return <>; + } + return ( []>(chat.messages); const [loadingAssistantResponse, setLoadingAssistantResponse] = useState(false); - const [loadingUserResponse, setLoadingUserResponse] = useState(false); - const [simulationComplete, setSimulationComplete] = useState(chat.simulationComplete || false); const [agenticState, setAgenticState] = useState(chat.agenticState || { last_agent_name: workflow.startAgent, }); @@ -51,6 +49,12 @@ export function Chat({ const [lastAgenticRequest, setLastAgenticRequest] = useState(null); const [lastAgenticResponse, setLastAgenticResponse] = useState(null); const [isProfileSelectorOpen, setIsProfileSelectorOpen] = useState(false); + const [optimisticMessages, setOptimisticMessages] = useState[]>(chat.messages); + + // reset optimistic messages when messages change + useEffect(() => { + setOptimisticMessages(messages); + }, [messages]); // collect published tool call results const toolCallResults: Record> = {}; @@ -59,6 +63,7 @@ export function Chat({ .forEach((message) => { toolCallResults[message.tool_call_id] = message; }); + console.log('toolCallResults', toolCallResults); function handleUserMessage(prompt: string) { const updatedMessages: z.infer[] = [...messages, { @@ -87,9 +92,12 @@ export function Chat({ } }, [messages, messageSubscriber]); - // get agent response + // get assistant response useEffect(() => { + console.log('stream useEffect called'); let ignore = false; + let eventSource: EventSource | null = null; + let msgs: z.infer[] = []; async function process() { setLoadingAssistantResponse(true); @@ -116,39 +124,76 @@ export function Chat({ setLastAgenticRequest(null); setLastAgenticResponse(null); + let streamId: string | null = null; try { - const response = await getAssistantResponse(request); + const response = await getAssistantResponseStreamId(request); if (ignore) { return; } - if (simulationComplete) { - return; - } - setLastAgenticRequest(response.rawRequest); - setLastAgenticResponse(response.rawResponse); - setMessages([...messages, ...response.messages.map((message) => ({ - ...message, - version: 'v1' as const, - chatId: '', - createdAt: new Date().toISOString(), - }))]); - setAgenticState(response.state); + streamId = response.streamId; } catch (err) { if (!ignore) { setFetchResponseError(`Failed to get assistant response: ${err instanceof Error ? err.message : 'Unknown error'}`); - } - } finally { - if (!ignore) { setLoadingAssistantResponse(false); } } + + if (ignore || !streamId) { + console.log('almost there', ignore, streamId); + return; + } + + // log the stream id + console.log('🔄 got assistant response', streamId); + + // read from SSE stream + eventSource = new EventSource(`/api/v1/stream-response/${streamId}`); + + eventSource.addEventListener("message", (event) => { + if (ignore) { + return; + } + + try { + const data = JSON.parse(event.data); + const msg = AgenticAPIChatMessage.parse(data); + const parsedMsg = convertFromAgenticAPIChatMessages([msg])[0]; + console.log('🔄 got assistant response chunk', parsedMsg); + msgs.push(parsedMsg); + setOptimisticMessages(prev => [...prev, parsedMsg]); + } catch (err) { + console.error('Failed to parse SSE message:', err); + setFetchResponseError(`Failed to parse SSE message: ${err instanceof Error ? err.message : 'Unknown error'}`); + setOptimisticMessages(messages); + } + }); + + eventSource.addEventListener('done', (event) => { + if (eventSource) { + eventSource.close(); + } + + console.log('🔄 got assistant response done', event.data); + const parsed: {state: unknown} = JSON.parse(event.data); + setAgenticState(parsed.state); + setMessages([...messages, ...msgs]); + setLoadingAssistantResponse(false); + }); + + eventSource.onerror = (error) => { + console.error('SSE Error:', error); + if (!ignore) { + setLoadingAssistantResponse(false); + setFetchResponseError('Stream connection failed'); + setOptimisticMessages(messages); + } + }; } - // if last message is not from role user - // or tool, return + // if last message is not a user message, return if (messages.length > 0) { const last = messages[messages.length - 1]; - if (last.role !== 'user' && last.role !== 'tool') { + if (last.role !== 'user') { return; } } @@ -162,8 +207,22 @@ export function Chat({ return () => { ignore = true; + console.log('stream useEffect cleanup called'); + if (eventSource) { + eventSource.close(); + } }; - }, [chat.simulated, messages, projectId, agenticState, workflow, fetchResponseError, systemMessage, simulationComplete, mcpServerUrls, toolWebhookUrl, testProfile]); + }, [ + messages, + projectId, + agenticState, + workflow, + systemMessage, + mcpServerUrls, + toolWebhookUrl, + testProfile, + fetchResponseError, + ]); const handleCopyChat = () => { const jsonString = JSON.stringify({ @@ -202,10 +261,9 @@ export function Chat({ /> )} - {!chat.simulated &&
+
-
} - {chat.simulated && !simulationComplete &&
- -
Simulating...
- -
} - {chat.simulated && simulationComplete &&

Simulation complete.

} +
; } \ No newline at end of file diff --git a/apps/rowboat/app/projects/[projectId]/playground/messages.tsx b/apps/rowboat/app/projects/[projectId]/playground/messages.tsx index f4e52110..cd4c979e 100644 --- a/apps/rowboat/app/projects/[projectId]/playground/messages.tsx +++ b/apps/rowboat/app/projects/[projectId]/playground/messages.tsx @@ -71,17 +71,6 @@ function AssistantMessageLoading() { ; } -function UserMessageLoading() { - return
-
- User -
-
- -
-
; -} - function ToolCalls({ toolCalls, results, @@ -101,20 +90,14 @@ function ToolCalls({ testProfile: z.infer | null; systemMessage: string | undefined; }) { - const resultsMap: Record> = {}; - return
{toolCalls.map(toolCall => { return })}
; @@ -123,21 +106,13 @@ function ToolCalls({ function ToolCall({ toolCall, result, - projectId, - messages, sender, workflow, - testProfile = null, - systemMessage, }: { toolCall: z.infer['tool_calls'][number]; result: z.infer | undefined; - projectId: string; - messages: z.infer[]; sender: string | null | undefined; workflow: z.infer; - testProfile: z.infer | null; - systemMessage: string | undefined; }) { let matchingWorkflowTool: z.infer | undefined; for (const tool of workflow.tools) { @@ -160,24 +135,6 @@ function ToolCall({ />; } -function ToolCallHeader({ - toolCall, - result, -}: { - toolCall: z.infer['tool_calls'][number]; - result: z.infer | undefined; -}) { - return
-
- {!result && } - {result && } -
- Function Call: {toolCall.function.name} -
-
-
; -} - function TransferToAgentToolCall({ result: availableResult, sender, @@ -211,7 +168,15 @@ function ClientToolCall({ return
{sender &&
{sender}
}
- +
+
+ {!availableResult && } + {availableResult && } +
+ Function Call: {toolCall.function.name} +
+
+
@@ -292,7 +257,6 @@ export function Messages({ messages, toolCallResults, loadingAssistantResponse, - loadingUserResponse, workflow, testProfile = null, systemMessage, @@ -302,7 +266,6 @@ export function Messages({ messages: z.infer[]; toolCallResults: Record>; loadingAssistantResponse: boolean; - loadingUserResponse: boolean; workflow: z.infer; testProfile: z.infer | null; systemMessage: string | undefined; @@ -314,7 +277,7 @@ export function Messages({ // scroll to bottom on new messages useEffect(() => { messagesEndRef.current?.scrollIntoView({ behavior: "smooth" }) - }, [messages, loadingAssistantResponse, loadingUserResponse]); + }, [messages, loadingAssistantResponse]); return
@@ -368,7 +331,6 @@ export function Messages({ return <>; })} {loadingAssistantResponse && } - {loadingUserResponse && }
; diff --git a/apps/rowboat/app/projects/[projectId]/workflow/app.tsx b/apps/rowboat/app/projects/[projectId]/workflow/app.tsx index 331e20b7..20b94273 100644 --- a/apps/rowboat/app/projects/[projectId]/workflow/app.tsx +++ b/apps/rowboat/app/projects/[projectId]/workflow/app.tsx @@ -9,17 +9,15 @@ import { WorkflowSelector } from "./workflow_selector"; import { Spinner } from "@heroui/react"; import { cloneWorkflow, createWorkflow, fetchPublishedWorkflowId, fetchWorkflow } from "../../../actions/workflow_actions"; import { listDataSources } from "../../../actions/datasource_actions"; +import { listMcpServers } from "@/app/actions/mcp_actions"; +import { getProjectConfig } from "@/app/actions/project_actions"; export function App({ projectId, useRag, - mcpServerUrls, - toolWebhookUrl, }: { projectId: string; useRag: boolean; - mcpServerUrls: Array>; - toolWebhookUrl: string; }) { const [selectorKey, setSelectorKey] = useState(0); const [workflow, setWorkflow] = useState> | null>(null); @@ -27,17 +25,23 @@ export function App({ const [dataSources, setDataSources] = useState>[] | null>(null); const [loading, setLoading] = useState(false); const [autoSelectIfOnlyOneWorkflow, setAutoSelectIfOnlyOneWorkflow] = useState(true); + const [mcpServerUrls, setMcpServerUrls] = useState>>([]); + const [toolWebhookUrl, setToolWebhookUrl] = useState(''); const handleSelect = useCallback(async (workflowId: string) => { setLoading(true); const workflow = await fetchWorkflow(projectId, workflowId); const publishedWorkflowId = await fetchPublishedWorkflowId(projectId); const dataSources = await listDataSources(projectId); + const mcpServers = await listMcpServers(projectId); + const projectConfig = await getProjectConfig(projectId); // Store the selected workflow ID in local storage localStorage.setItem(`lastWorkflowId_${projectId}`, workflowId); setWorkflow(workflow); setPublishedWorkflowId(publishedWorkflowId); setDataSources(dataSources); + setMcpServerUrls(mcpServers); + setToolWebhookUrl(projectConfig.webhookUrl ?? ''); setLoading(false); }, [projectId]); diff --git a/apps/rowboat/app/projects/[projectId]/workflow/page.tsx b/apps/rowboat/app/projects/[projectId]/workflow/page.tsx index be6d588c..09ec0573 100644 --- a/apps/rowboat/app/projects/[projectId]/workflow/page.tsx +++ b/apps/rowboat/app/projects/[projectId]/workflow/page.tsx @@ -13,18 +13,16 @@ export default async function Page({ }: { params: { projectId: string }; }) { + console.log('->>> workflow page being rendered'); const project = await projectsCollection.findOne({ _id: params.projectId, }); if (!project) { notFound(); } - const toolWebhookUrl = project.webhookUrl ?? ''; return ; } diff --git a/apps/rowboat_agents/pyproject.toml b/apps/rowboat_agents/pyproject.toml index 42334612..6bdc8bbe 100644 --- a/apps/rowboat_agents/pyproject.toml +++ b/apps/rowboat_agents/pyproject.toml @@ -80,6 +80,7 @@ python-docx = "^1.1.2" python-dotenv = "^1.0.1" pytz = "^2024.2" qdrant-client = "*" +Quart = "^0.20.0" RapidFuzz = "^3.11.0" redis = "^5.2.1" requests = "^2.32.3" diff --git a/apps/rowboat_agents/requirements.txt b/apps/rowboat_agents/requirements.txt index a7035d18..3d2be2cc 100644 --- a/apps/rowboat_agents/requirements.txt +++ b/apps/rowboat_agents/requirements.txt @@ -66,6 +66,7 @@ python-docx==1.1.2 python-dotenv==1.0.1 pytz==2024.2 qdrant-client +Quart==0.20.0 RapidFuzz==3.11.0 redis==5.2.1 requests==2.32.3 diff --git a/apps/rowboat_agents/src/app/main.py b/apps/rowboat_agents/src/app/main.py index ccc65d45..6526c9e3 100644 --- a/apps/rowboat_agents/src/app/main.py +++ b/apps/rowboat_agents/src/app/main.py @@ -1,13 +1,13 @@ -from flask import Flask, request, jsonify, Response +from quart import Quart, request, jsonify, Response from datetime import datetime from functools import wraps import os import redis import uuid import json -import asyncio from hypercorn.config import Config from hypercorn.asyncio import serve +import asyncio from src.graph.core import run_turn, run_turn_streamed from src.graph.tools import RAG_TOOL, CLOSE_CHAT_TOOL @@ -17,19 +17,34 @@ from pprint import pprint logger = common_logger redis_client = redis.from_url(os.environ.get('REDIS_URL', 'redis://localhost:6379')) -app = Flask(__name__) +app = Quart(__name__) + +# filter out agent transfer messages using a function +def is_agent_transfer_message(msg): + if (msg.get("role") == "assistant" and + msg.get("content") is None and + msg.get("tool_calls") is not None and + len(msg.get("tool_calls")) > 0 and + msg.get("tool_calls")[0].get("function").get("name") == "transfer_to_agent"): + return True + if (msg.get("role") == "tool" and + msg.get("tool_calls") is None and + msg.get("tool_call_id") is not None and + msg.get("tool_name") == "transfer_to_agent"): + return True + return False @app.route("/health", methods=["GET"]) -def health(): +async def health(): return jsonify({"status": "ok"}) @app.route("/") -def home(): +async def home(): return "Hello, World!" def require_api_key(f): @wraps(f) - def decorated(*args, **kwargs): + async def decorated(*args, **kwargs): auth_header = request.headers.get('Authorization') if not auth_header or not auth_header.startswith('Bearer '): return jsonify({'error': 'Missing or invalid authorization header'}), 401 @@ -39,16 +54,16 @@ def require_api_key(f): if actual and token != actual: return jsonify({'error': 'Invalid API key'}), 403 - return f(*args, **kwargs) + return await f(*args, **kwargs) return decorated @app.route("/chat", methods=["POST"]) @require_api_key -def chat(): +async def chat(): logger.info('='*100) logger.info(f"{'*'*100}Running server mode{'*'*100}") try: - data = request.get_json() + data = await request.get_json() logger.info('Complete request:') logger.info(data) logger.info('-'*100) @@ -56,9 +71,12 @@ def chat(): start_time = datetime.now() config = read_json_from_file("./configs/default_config.json") + # filter out agent transfer messages + input_messages = [msg for msg in data.get("messages", []) if not is_agent_transfer_message(msg)] + logger.info('Beginning turn') resp_messages, resp_tokens_used, resp_state = run_turn( - messages=data.get("messages", []), + messages=input_messages, start_agent_name=data.get("startAgent", ""), agent_configs=data.get("agents", []), tool_configs=data.get("tools", []), @@ -94,19 +112,27 @@ def chat(): @app.route("/chat_stream_init", methods=["POST"]) @require_api_key -def chat_stream_init(): +async def chat_stream_init(): # create a uuid for the stream stream_id = str(uuid.uuid4()) # store the request data in redis with 10 minute TTL - data = request.get_json() + data = await request.get_json() redis_client.setex(f"stream_request_{stream_id}", 600, json.dumps(data)) - return jsonify({"stream_id": stream_id}) + print('* stream init'*200) + + return jsonify({"streamId": stream_id}) + +def format_sse(data: dict, event: str = None) -> str: + msg = f"data: {json.dumps(data)}\n\n" + if event is not None: + msg = f"event: {event}\n{msg}" + return msg @app.route("/chat_stream/", methods=["GET"]) @require_api_key -def chat_stream(stream_id): +async def chat_stream(stream_id): # get the request data from redis request_data = redis_client.get(f"stream_request_{stream_id}") if not request_data: @@ -114,17 +140,18 @@ def chat_stream(stream_id): request_data = json.loads(request_data) config = read_json_from_file("./configs/default_config.json") + + # filter out agent transfer messages + input_messages = [msg for msg in request_data["messages"] if not is_agent_transfer_message(msg)] # Preprocess messages to handle null content and role issues - for msg in request_data["messages"]: - # Handle null content in assistant messages with tool calls + for msg in input_messages: if (msg.get("role") == "assistant" and msg.get("content") is None and msg.get("tool_calls") is not None and len(msg.get("tool_calls")) > 0): msg["content"] = "Calling tool" - # Handle role issues if msg.get("role") == "tool": msg["role"] = "developer" elif not msg.get("role"): @@ -135,12 +162,11 @@ def chat_stream(stream_id): print('*'*200) pprint(request_data) print('='*200) - - async def process_stream(): + async def generate(): try: async for event_type, event_data in run_turn_streamed( - messages=request_data.get("messages", []), + messages=input_messages, start_agent_name=request_data.get("startAgent", ""), agent_configs=request_data.get("agents", []), tool_configs=request_data.get("tools", []), @@ -153,43 +179,16 @@ def chat_stream(stream_id): print('*'*200) print("Yielding message:") print('*'*200) - to_yield = f"event: message\ndata: {json.dumps(event_data)}\n\n" - print(to_yield) - print('='*200) - yield to_yield + yield format_sse(event_data, "message") elif event_type == 'done': print('*'*200) print("Yielding done:") print('*'*200) - to_yield = f"event: done\ndata: {json.dumps(event_data)}\n\n" - print(to_yield) - print('='*200) - yield to_yield + yield format_sse(event_data, "done") except Exception as e: logger.error(f"Streaming error: {str(e)}") - yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n" - - def generate(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - async def get_all_chunks(): - chunks = [] - async for chunk in process_stream(): - chunks.append(chunk) - return chunks - - chunks = loop.run_until_complete(get_all_chunks()) - for chunk in chunks: - yield chunk - - except Exception as e: - logger.error(f"Error in generate: {e}") - raise - finally: - loop.close() + yield format_sse({"error": str(e)}, "error") return Response(generate(), mimetype='text/event-stream') diff --git a/apps/rowboat_agents/src/graph/core.py b/apps/rowboat_agents/src/graph/core.py index 371e9aee..af00f870 100644 --- a/apps/rowboat_agents/src/graph/core.py +++ b/apps/rowboat_agents/src/graph/core.py @@ -1,6 +1,7 @@ from copy import deepcopy from datetime import datetime - +import json +import uuid import logging from .helpers.access import ( get_agent_by_name, @@ -285,16 +286,45 @@ async def run_turn_streamed( # Update current agent when it changes elif event.type == "agent_updated_stream_event": current_agent = event.new_agent + tool_call_id = str(uuid.uuid4()) + + # yield the transfer invocation message = { - 'content': f"Agent changed to {current_agent.name}", + 'content': None, 'role': 'assistant', 'sender': current_agent.name, - 'tool_calls': None, + 'tool_calls': [{ + 'function': { + 'name': 'transfer_to_agent', + 'arguments': json.dumps({ + 'assistant': event.new_agent.name + }) + }, + 'id': tool_call_id, + 'type': 'function' + }], 'tool_call_id': None, + 'tool_name': None, 'response_type': 'internal' } print("Yielding message: ", message) yield ('message', message) + + # yield the transfer result + message = { + 'content': json.dumps({ + 'assistant': event.new_agent.name + }), + 'role': 'tool', + 'sender': None, + 'tool_calls': None, + 'tool_call_id': tool_call_id, + 'tool_name': 'transfer_to_agent', + } + print("Yielding message: ", message) + yield ('message', message) + + current_agent = event.new_agent continue # Handle run items (tools, messages, etc)