Next.js changes for playground streaming

This commit is contained in:
ramnique 2025-03-25 01:42:22 +05:30 committed by Ramnique Singh
parent 24efe0e887
commit 77b53696b6
14 changed files with 290 additions and 160 deletions

View file

@ -1,5 +1,5 @@
'use server';
import { convertFromAgenticAPIChatMessages } from "../lib/types/agents_api_types";
import { AgenticAPIInitStreamResponse, convertFromAgenticAPIChatMessages } from "../lib/types/agents_api_types";
import { AgenticAPIChatRequest } from "../lib/types/agents_api_types";
import { WebpageCrawlResponse } from "../lib/types/tool_types";
import { webpagesCollection } from "../lib/mongodb";
@ -7,7 +7,7 @@ import { z } from 'zod';
import FirecrawlApp, { ScrapeResponse } from '@mendable/firecrawl-js';
import { apiV1 } from "rowboat-shared";
import { Claims, getSession } from "@auth0/nextjs-auth0";
import { getAgenticApiResponse } from "../lib/utils";
import { getAgenticApiResponse, getAgenticResponseStreamId } from "../lib/utils";
import { check_query_limit } from "../lib/rate_limiting";
import { QueryLimitError } from "../lib/client_utils";
import { projectAuthCheck } from "./project_actions";
@ -85,3 +85,13 @@ export async function getAssistantResponse(request: z.infer<typeof AgenticAPICha
rawResponse: response.rawAPIResponse,
};
}
export async function getAssistantResponseStreamId(request: z.infer<typeof AgenticAPIChatRequest>): Promise<z.infer<typeof AgenticAPIInitStreamResponse>> {
await projectAuthCheck(request.projectId);
if (!await check_query_limit(request.projectId)) {
throw new QueryLimitError();
}
const response = await getAgenticResponseStreamId(request);
return response;
}

View file

@ -6,6 +6,7 @@ import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js";
import { projectAuthCheck } from "./project_actions";
import { projectsCollection } from "../lib/mongodb";
import { Project } from "../lib/types/project_types";
import { MCPServer } from "../lib/types/types";
export async function fetchMcpTools(projectId: string): Promise<z.infer<typeof WorkflowTool>[]> {
await projectAuthCheck(projectId);
@ -71,4 +72,12 @@ export async function updateMcpServers(projectId: string, mcpServers: z.infer<ty
await projectsCollection.updateOne({
_id: projectId,
}, { $set: { mcpServers } });
}
export async function listMcpServers(projectId: string): Promise<z.infer<typeof MCPServer>[]> {
await projectAuthCheck(projectId);
const project = await projectsCollection.findOne({
_id: projectId,
});
return project?.mcpServers ?? [];
}

View file

@ -0,0 +1,45 @@
export async function GET(request: Request, { params }: { params: { streamId: string } }) {
// Replace with your actual upstream SSE endpoint.
const upstreamUrl = `${process.env.AGENTS_API_URL}/chat_stream/${params.streamId}`;
console.log('upstreamUrl', upstreamUrl);
// Fetch the upstream SSE stream.
const upstreamResponse = await fetch(upstreamUrl, {
headers: {
'Authorization': `Bearer ${process.env.AGENTS_API_KEY}`,
},
cache: 'no-store',
});
// If the upstream request fails, return a 502 Bad Gateway.
if (!upstreamResponse.ok || !upstreamResponse.body) {
return new Response("Error connecting to upstream SSE stream", { status: 502 });
}
const reader = upstreamResponse.body.getReader();
const stream = new ReadableStream({
async start(controller) {
try {
// Read from the upstream stream continuously.
while (true) {
const { done, value } = await reader.read();
if (done) break;
// Immediately enqueue each received chunk.
controller.enqueue(value);
}
controller.close();
} catch (error) {
controller.error(error);
}
},
});
return new Response(stream, {
headers: {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
});
}

View file

@ -64,6 +64,10 @@ export const AgenticAPIChatResponse = z.object({
state: z.unknown(),
});
export const AgenticAPIInitStreamResponse = z.object({
streamId: z.string(),
});
export function convertWorkflowToAgenticAPI(workflow: z.infer<typeof Workflow>): {
agents: z.infer<typeof AgenticAPIAgent>[];
tools: z.infer<typeof AgenticAPITool>[];

View file

@ -1,4 +1,4 @@
import { AgenticAPIChatResponse, AgenticAPIChatRequest, AgenticAPIChatMessage } from "./types/agents_api_types";
import { AgenticAPIChatResponse, AgenticAPIChatRequest, AgenticAPIChatMessage, AgenticAPIInitStreamResponse } from "./types/agents_api_types";
import { z } from "zod";
import { generateObject } from "ai";
import { ApiMessage } from "./types/types";
@ -35,6 +35,29 @@ export async function getAgenticApiResponse(
};
}
export async function getAgenticResponseStreamId(
request: z.infer<typeof AgenticAPIChatRequest>,
): Promise<z.infer<typeof AgenticAPIInitStreamResponse>> {
// call agentic api
console.log(`sending agentic api init stream request`, JSON.stringify(request));
const response = await fetch(process.env.AGENTS_API_URL + '/chat_stream_init', {
method: 'POST',
body: JSON.stringify(request),
headers: {
'Content-Type': 'application/json',
'Authorization': `Bearer ${process.env.AGENTS_API_KEY || 'test'}`,
},
});
if (!response.ok) {
console.error('Failed to call agentic init stream api', response);
throw new Error(`Failed to call agentic init stream api: ${response.statusText}`);
}
const responseJson = await response.json();
console.log(`received agentic api init stream response`, JSON.stringify(responseJson));
const result: z.infer<typeof AgenticAPIInitStreamResponse> = responseJson;
return result;
}
// create a PrefixLogger class that wraps console.log with a prefix
// and allows chaining with a parent logger
export class PrefixLogger {

View file

@ -48,10 +48,6 @@ export function App({
setCounter(counter + 1);
}
if (hidden) {
return <></>;
}
function handleNewChatButtonClick() {
setCounter(counter + 1);
setChat({
@ -63,6 +59,10 @@ export function App({
});
}
if (hidden) {
return <></>;
}
return (
<Pane
title="PLAYGROUND"

View file

@ -1,10 +1,10 @@
'use client';
import { getAssistantResponse } from "../../../actions/actions";
import { useEffect, useState } from "react";
import { getAssistantResponseStreamId } from "../../../actions/actions";
import { useEffect, useOptimistic, useState } from "react";
import { Messages } from "./messages";
import z from "zod";
import { MCPServer, PlaygroundChat } from "../../../lib/types/types";
import { convertToAgenticAPIChatMessages } from "../../../lib/types/agents_api_types";
import { AgenticAPIChatMessage, convertFromAgenticAPIChatMessages, convertToAgenticAPIChatMessages } from "../../../lib/types/agents_api_types";
import { convertWorkflowToAgenticAPI } from "../../../lib/types/agents_api_types";
import { AgenticAPIChatRequest } from "../../../lib/types/agents_api_types";
import { Workflow } from "../../../lib/types/workflow_types";
@ -22,7 +22,7 @@ export function Chat({
projectId,
workflow,
messageSubscriber,
testProfile=null,
testProfile = null,
onTestProfileChange,
systemMessage,
onSystemMessageChange,
@ -42,8 +42,6 @@ export function Chat({
}) {
const [messages, setMessages] = useState<z.infer<typeof apiV1.ChatMessage>[]>(chat.messages);
const [loadingAssistantResponse, setLoadingAssistantResponse] = useState<boolean>(false);
const [loadingUserResponse, setLoadingUserResponse] = useState<boolean>(false);
const [simulationComplete, setSimulationComplete] = useState<boolean>(chat.simulationComplete || false);
const [agenticState, setAgenticState] = useState<unknown>(chat.agenticState || {
last_agent_name: workflow.startAgent,
});
@ -51,6 +49,12 @@ export function Chat({
const [lastAgenticRequest, setLastAgenticRequest] = useState<unknown | null>(null);
const [lastAgenticResponse, setLastAgenticResponse] = useState<unknown | null>(null);
const [isProfileSelectorOpen, setIsProfileSelectorOpen] = useState(false);
const [optimisticMessages, setOptimisticMessages] = useState<z.infer<typeof apiV1.ChatMessage>[]>(chat.messages);
// reset optimistic messages when messages change
useEffect(() => {
setOptimisticMessages(messages);
}, [messages]);
// collect published tool call results
const toolCallResults: Record<string, z.infer<typeof apiV1.ToolMessage>> = {};
@ -59,6 +63,7 @@ export function Chat({
.forEach((message) => {
toolCallResults[message.tool_call_id] = message;
});
console.log('toolCallResults', toolCallResults);
function handleUserMessage(prompt: string) {
const updatedMessages: z.infer<typeof apiV1.ChatMessage>[] = [...messages, {
@ -87,9 +92,12 @@ export function Chat({
}
}, [messages, messageSubscriber]);
// get agent response
// get assistant response
useEffect(() => {
console.log('stream useEffect called');
let ignore = false;
let eventSource: EventSource | null = null;
let msgs: z.infer<typeof apiV1.ChatMessage>[] = [];
async function process() {
setLoadingAssistantResponse(true);
@ -116,39 +124,76 @@ export function Chat({
setLastAgenticRequest(null);
setLastAgenticResponse(null);
let streamId: string | null = null;
try {
const response = await getAssistantResponse(request);
const response = await getAssistantResponseStreamId(request);
if (ignore) {
return;
}
if (simulationComplete) {
return;
}
setLastAgenticRequest(response.rawRequest);
setLastAgenticResponse(response.rawResponse);
setMessages([...messages, ...response.messages.map((message) => ({
...message,
version: 'v1' as const,
chatId: '',
createdAt: new Date().toISOString(),
}))]);
setAgenticState(response.state);
streamId = response.streamId;
} catch (err) {
if (!ignore) {
setFetchResponseError(`Failed to get assistant response: ${err instanceof Error ? err.message : 'Unknown error'}`);
}
} finally {
if (!ignore) {
setLoadingAssistantResponse(false);
}
}
if (ignore || !streamId) {
console.log('almost there', ignore, streamId);
return;
}
// log the stream id
console.log('🔄 got assistant response', streamId);
// read from SSE stream
eventSource = new EventSource(`/api/v1/stream-response/${streamId}`);
eventSource.addEventListener("message", (event) => {
if (ignore) {
return;
}
try {
const data = JSON.parse(event.data);
const msg = AgenticAPIChatMessage.parse(data);
const parsedMsg = convertFromAgenticAPIChatMessages([msg])[0];
console.log('🔄 got assistant response chunk', parsedMsg);
msgs.push(parsedMsg);
setOptimisticMessages(prev => [...prev, parsedMsg]);
} catch (err) {
console.error('Failed to parse SSE message:', err);
setFetchResponseError(`Failed to parse SSE message: ${err instanceof Error ? err.message : 'Unknown error'}`);
setOptimisticMessages(messages);
}
});
eventSource.addEventListener('done', (event) => {
if (eventSource) {
eventSource.close();
}
console.log('🔄 got assistant response done', event.data);
const parsed: {state: unknown} = JSON.parse(event.data);
setAgenticState(parsed.state);
setMessages([...messages, ...msgs]);
setLoadingAssistantResponse(false);
});
eventSource.onerror = (error) => {
console.error('SSE Error:', error);
if (!ignore) {
setLoadingAssistantResponse(false);
setFetchResponseError('Stream connection failed');
setOptimisticMessages(messages);
}
};
}
// if last message is not from role user
// or tool, return
// if last message is not a user message, return
if (messages.length > 0) {
const last = messages[messages.length - 1];
if (last.role !== 'user' && last.role !== 'tool') {
if (last.role !== 'user') {
return;
}
}
@ -162,8 +207,22 @@ export function Chat({
return () => {
ignore = true;
console.log('stream useEffect cleanup called');
if (eventSource) {
eventSource.close();
}
};
}, [chat.simulated, messages, projectId, agenticState, workflow, fetchResponseError, systemMessage, simulationComplete, mcpServerUrls, toolWebhookUrl, testProfile]);
}, [
messages,
projectId,
agenticState,
workflow,
systemMessage,
mcpServerUrls,
toolWebhookUrl,
testProfile,
fetchResponseError,
]);
const handleCopyChat = () => {
const jsonString = JSON.stringify({
@ -202,10 +261,9 @@ export function Chat({
/>
<Messages
projectId={projectId}
messages={messages}
messages={optimisticMessages}
toolCallResults={toolCallResults}
loadingAssistantResponse={loadingAssistantResponse}
loadingUserResponse={loadingUserResponse}
workflow={workflow}
testProfile={testProfile}
systemMessage={systemMessage}
@ -226,26 +284,12 @@ export function Chat({
</Button>
</div>
)}
{!chat.simulated && <div className="max-w-[768px] mx-auto">
<div className="max-w-[768px] mx-auto">
<ComposeBox
handleUserMessage={handleUserMessage}
messages={messages}
/>
</div>}
{chat.simulated && !simulationComplete && <div className="p-2 bg-gray-50 border border-gray-200 flex items-center justify-center gap-2">
<Spinner size="sm" />
<div className="text-sm text-gray-500 animate-pulse">Simulating...</div>
<Button
size="sm"
color="danger"
onPress={() => {
setSimulationComplete(true);
}}
>
Stop
</Button>
</div>}
{chat.simulated && simulationComplete && <p className="text-center text-sm">Simulation complete.</p>}
</div>
</div>
</div>;
}

View file

@ -71,17 +71,6 @@ function AssistantMessageLoading() {
</div>;
}
function UserMessageLoading() {
return <div className="self-end ml-[30%] flex flex-col">
<div className="text-right text-gray-500 dark:text-gray-400 text-sm mr-3">
User
</div>
<div className="bg-gray-100 dark:bg-gray-800 p-3 rounded-lg rounded-br-none animate-pulse w-20 text-gray-800 dark:text-gray-200">
<Spinner size="sm" />
</div>
</div>;
}
function ToolCalls({
toolCalls,
results,
@ -101,20 +90,14 @@ function ToolCalls({
testProfile: z.infer<typeof TestProfile> | null;
systemMessage: string | undefined;
}) {
const resultsMap: Record<string, z.infer<typeof apiV1.ToolMessage>> = {};
return <div className="flex flex-col gap-4">
{toolCalls.map(toolCall => {
return <ToolCall
key={toolCall.id}
toolCall={toolCall}
result={results[toolCall.id]}
projectId={projectId}
messages={messages}
sender={sender}
workflow={workflow}
testProfile={testProfile}
systemMessage={systemMessage}
/>
})}
</div>;
@ -123,21 +106,13 @@ function ToolCalls({
function ToolCall({
toolCall,
result,
projectId,
messages,
sender,
workflow,
testProfile = null,
systemMessage,
}: {
toolCall: z.infer<typeof apiV1.AssistantMessageWithToolCalls>['tool_calls'][number];
result: z.infer<typeof apiV1.ToolMessage> | undefined;
projectId: string;
messages: z.infer<typeof apiV1.ChatMessage>[];
sender: string | null | undefined;
workflow: z.infer<typeof Workflow>;
testProfile: z.infer<typeof TestProfile> | null;
systemMessage: string | undefined;
}) {
let matchingWorkflowTool: z.infer<typeof WorkflowTool> | undefined;
for (const tool of workflow.tools) {
@ -160,24 +135,6 @@ function ToolCall({
/>;
}
function ToolCallHeader({
toolCall,
result,
}: {
toolCall: z.infer<typeof apiV1.AssistantMessageWithToolCalls>['tool_calls'][number];
result: z.infer<typeof apiV1.ToolMessage> | undefined;
}) {
return <div className="flex flex-col gap-1">
<div className='shrink-0 flex gap-2 items-center'>
{!result && <Spinner size="sm" />}
{result && <CircleCheckIcon size={16} />}
<div className='font-semibold text-sm'>
Function Call: <code className='bg-gray-100 dark:bg-neutral-800 px-2 py-0.5 rounded font-mono'>{toolCall.function.name}</code>
</div>
</div>
</div>;
}
function TransferToAgentToolCall({
result: availableResult,
sender,
@ -211,7 +168,15 @@ function ClientToolCall({
return <div className="flex flex-col gap-1">
{sender && <div className='text-gray-500 text-sm ml-3'>{sender}</div>}
<div className='border border-gray-300 p-2 pt-2 rounded-lg rounded-bl-none flex flex-col gap-2 mr-[30%]'>
<ToolCallHeader toolCall={toolCall} result={availableResult} />
<div className="flex flex-col gap-1">
<div className='shrink-0 flex gap-2 items-center'>
{!availableResult && <Spinner size="sm" />}
{availableResult && <CircleCheckIcon size={16} />}
<div className='font-semibold text-sm'>
Function Call: <code className='bg-gray-100 dark:bg-neutral-800 px-2 py-0.5 rounded font-mono'>{toolCall.function.name}</code>
</div>
</div>
</div>
<div className='flex flex-col gap-2'>
<ExpandableContent label='Params' content={toolCall.function.arguments} expanded={false} />
@ -292,7 +257,6 @@ export function Messages({
messages,
toolCallResults,
loadingAssistantResponse,
loadingUserResponse,
workflow,
testProfile = null,
systemMessage,
@ -302,7 +266,6 @@ export function Messages({
messages: z.infer<typeof apiV1.ChatMessage>[];
toolCallResults: Record<string, z.infer<typeof apiV1.ToolMessage>>;
loadingAssistantResponse: boolean;
loadingUserResponse: boolean;
workflow: z.infer<typeof Workflow>;
testProfile: z.infer<typeof TestProfile> | null;
systemMessage: string | undefined;
@ -314,7 +277,7 @@ export function Messages({
// scroll to bottom on new messages
useEffect(() => {
messagesEndRef.current?.scrollIntoView({ behavior: "smooth" })
}, [messages, loadingAssistantResponse, loadingUserResponse]);
}, [messages, loadingAssistantResponse]);
return <div className="grow pt-4 overflow-auto">
<div className="max-w-[768px] mx-auto flex flex-col gap-8">
@ -368,7 +331,6 @@ export function Messages({
return <></>;
})}
{loadingAssistantResponse && <AssistantMessageLoading key="assistant-loading" />}
{loadingUserResponse && <UserMessageLoading key="user-loading" />}
<div ref={messagesEndRef} />
</div>
</div>;

View file

@ -9,17 +9,15 @@ import { WorkflowSelector } from "./workflow_selector";
import { Spinner } from "@heroui/react";
import { cloneWorkflow, createWorkflow, fetchPublishedWorkflowId, fetchWorkflow } from "../../../actions/workflow_actions";
import { listDataSources } from "../../../actions/datasource_actions";
import { listMcpServers } from "@/app/actions/mcp_actions";
import { getProjectConfig } from "@/app/actions/project_actions";
export function App({
projectId,
useRag,
mcpServerUrls,
toolWebhookUrl,
}: {
projectId: string;
useRag: boolean;
mcpServerUrls: Array<z.infer<typeof MCPServer>>;
toolWebhookUrl: string;
}) {
const [selectorKey, setSelectorKey] = useState(0);
const [workflow, setWorkflow] = useState<WithStringId<z.infer<typeof Workflow>> | null>(null);
@ -27,17 +25,23 @@ export function App({
const [dataSources, setDataSources] = useState<WithStringId<z.infer<typeof DataSource>>[] | null>(null);
const [loading, setLoading] = useState(false);
const [autoSelectIfOnlyOneWorkflow, setAutoSelectIfOnlyOneWorkflow] = useState(true);
const [mcpServerUrls, setMcpServerUrls] = useState<Array<z.infer<typeof MCPServer>>>([]);
const [toolWebhookUrl, setToolWebhookUrl] = useState<string>('');
const handleSelect = useCallback(async (workflowId: string) => {
setLoading(true);
const workflow = await fetchWorkflow(projectId, workflowId);
const publishedWorkflowId = await fetchPublishedWorkflowId(projectId);
const dataSources = await listDataSources(projectId);
const mcpServers = await listMcpServers(projectId);
const projectConfig = await getProjectConfig(projectId);
// Store the selected workflow ID in local storage
localStorage.setItem(`lastWorkflowId_${projectId}`, workflowId);
setWorkflow(workflow);
setPublishedWorkflowId(publishedWorkflowId);
setDataSources(dataSources);
setMcpServerUrls(mcpServers);
setToolWebhookUrl(projectConfig.webhookUrl ?? '');
setLoading(false);
}, [projectId]);

View file

@ -13,18 +13,16 @@ export default async function Page({
}: {
params: { projectId: string };
}) {
console.log('->>> workflow page being rendered');
const project = await projectsCollection.findOne({
_id: params.projectId,
});
if (!project) {
notFound();
}
const toolWebhookUrl = project.webhookUrl ?? '';
return <App
projectId={params.projectId}
useRag={USE_RAG}
mcpServerUrls={project.mcpServers ?? []}
toolWebhookUrl={toolWebhookUrl}
/>;
}

View file

@ -80,6 +80,7 @@ python-docx = "^1.1.2"
python-dotenv = "^1.0.1"
pytz = "^2024.2"
qdrant-client = "*"
Quart = "^0.20.0"
RapidFuzz = "^3.11.0"
redis = "^5.2.1"
requests = "^2.32.3"

View file

@ -66,6 +66,7 @@ python-docx==1.1.2
python-dotenv==1.0.1
pytz==2024.2
qdrant-client
Quart==0.20.0
RapidFuzz==3.11.0
redis==5.2.1
requests==2.32.3

View file

@ -1,13 +1,13 @@
from flask import Flask, request, jsonify, Response
from quart import Quart, request, jsonify, Response
from datetime import datetime
from functools import wraps
import os
import redis
import uuid
import json
import asyncio
from hypercorn.config import Config
from hypercorn.asyncio import serve
import asyncio
from src.graph.core import run_turn, run_turn_streamed
from src.graph.tools import RAG_TOOL, CLOSE_CHAT_TOOL
@ -17,19 +17,34 @@ from pprint import pprint
logger = common_logger
redis_client = redis.from_url(os.environ.get('REDIS_URL', 'redis://localhost:6379'))
app = Flask(__name__)
app = Quart(__name__)
# filter out agent transfer messages using a function
def is_agent_transfer_message(msg):
if (msg.get("role") == "assistant" and
msg.get("content") is None and
msg.get("tool_calls") is not None and
len(msg.get("tool_calls")) > 0 and
msg.get("tool_calls")[0].get("function").get("name") == "transfer_to_agent"):
return True
if (msg.get("role") == "tool" and
msg.get("tool_calls") is None and
msg.get("tool_call_id") is not None and
msg.get("tool_name") == "transfer_to_agent"):
return True
return False
@app.route("/health", methods=["GET"])
def health():
async def health():
return jsonify({"status": "ok"})
@app.route("/")
def home():
async def home():
return "Hello, World!"
def require_api_key(f):
@wraps(f)
def decorated(*args, **kwargs):
async def decorated(*args, **kwargs):
auth_header = request.headers.get('Authorization')
if not auth_header or not auth_header.startswith('Bearer '):
return jsonify({'error': 'Missing or invalid authorization header'}), 401
@ -39,16 +54,16 @@ def require_api_key(f):
if actual and token != actual:
return jsonify({'error': 'Invalid API key'}), 403
return f(*args, **kwargs)
return await f(*args, **kwargs)
return decorated
@app.route("/chat", methods=["POST"])
@require_api_key
def chat():
async def chat():
logger.info('='*100)
logger.info(f"{'*'*100}Running server mode{'*'*100}")
try:
data = request.get_json()
data = await request.get_json()
logger.info('Complete request:')
logger.info(data)
logger.info('-'*100)
@ -56,9 +71,12 @@ def chat():
start_time = datetime.now()
config = read_json_from_file("./configs/default_config.json")
# filter out agent transfer messages
input_messages = [msg for msg in data.get("messages", []) if not is_agent_transfer_message(msg)]
logger.info('Beginning turn')
resp_messages, resp_tokens_used, resp_state = run_turn(
messages=data.get("messages", []),
messages=input_messages,
start_agent_name=data.get("startAgent", ""),
agent_configs=data.get("agents", []),
tool_configs=data.get("tools", []),
@ -94,19 +112,27 @@ def chat():
@app.route("/chat_stream_init", methods=["POST"])
@require_api_key
def chat_stream_init():
async def chat_stream_init():
# create a uuid for the stream
stream_id = str(uuid.uuid4())
# store the request data in redis with 10 minute TTL
data = request.get_json()
data = await request.get_json()
redis_client.setex(f"stream_request_{stream_id}", 600, json.dumps(data))
return jsonify({"stream_id": stream_id})
print('* stream init'*200)
return jsonify({"streamId": stream_id})
def format_sse(data: dict, event: str = None) -> str:
msg = f"data: {json.dumps(data)}\n\n"
if event is not None:
msg = f"event: {event}\n{msg}"
return msg
@app.route("/chat_stream/<stream_id>", methods=["GET"])
@require_api_key
def chat_stream(stream_id):
async def chat_stream(stream_id):
# get the request data from redis
request_data = redis_client.get(f"stream_request_{stream_id}")
if not request_data:
@ -114,17 +140,18 @@ def chat_stream(stream_id):
request_data = json.loads(request_data)
config = read_json_from_file("./configs/default_config.json")
# filter out agent transfer messages
input_messages = [msg for msg in request_data["messages"] if not is_agent_transfer_message(msg)]
# Preprocess messages to handle null content and role issues
for msg in request_data["messages"]:
# Handle null content in assistant messages with tool calls
for msg in input_messages:
if (msg.get("role") == "assistant" and
msg.get("content") is None and
msg.get("tool_calls") is not None and
len(msg.get("tool_calls")) > 0):
msg["content"] = "Calling tool"
# Handle role issues
if msg.get("role") == "tool":
msg["role"] = "developer"
elif not msg.get("role"):
@ -135,12 +162,11 @@ def chat_stream(stream_id):
print('*'*200)
pprint(request_data)
print('='*200)
async def process_stream():
async def generate():
try:
async for event_type, event_data in run_turn_streamed(
messages=request_data.get("messages", []),
messages=input_messages,
start_agent_name=request_data.get("startAgent", ""),
agent_configs=request_data.get("agents", []),
tool_configs=request_data.get("tools", []),
@ -153,43 +179,16 @@ def chat_stream(stream_id):
print('*'*200)
print("Yielding message:")
print('*'*200)
to_yield = f"event: message\ndata: {json.dumps(event_data)}\n\n"
print(to_yield)
print('='*200)
yield to_yield
yield format_sse(event_data, "message")
elif event_type == 'done':
print('*'*200)
print("Yielding done:")
print('*'*200)
to_yield = f"event: done\ndata: {json.dumps(event_data)}\n\n"
print(to_yield)
print('='*200)
yield to_yield
yield format_sse(event_data, "done")
except Exception as e:
logger.error(f"Streaming error: {str(e)}")
yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
def generate():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
async def get_all_chunks():
chunks = []
async for chunk in process_stream():
chunks.append(chunk)
return chunks
chunks = loop.run_until_complete(get_all_chunks())
for chunk in chunks:
yield chunk
except Exception as e:
logger.error(f"Error in generate: {e}")
raise
finally:
loop.close()
yield format_sse({"error": str(e)}, "error")
return Response(generate(), mimetype='text/event-stream')

View file

@ -1,6 +1,7 @@
from copy import deepcopy
from datetime import datetime
import json
import uuid
import logging
from .helpers.access import (
get_agent_by_name,
@ -285,16 +286,45 @@ async def run_turn_streamed(
# Update current agent when it changes
elif event.type == "agent_updated_stream_event":
current_agent = event.new_agent
tool_call_id = str(uuid.uuid4())
# yield the transfer invocation
message = {
'content': f"Agent changed to {current_agent.name}",
'content': None,
'role': 'assistant',
'sender': current_agent.name,
'tool_calls': None,
'tool_calls': [{
'function': {
'name': 'transfer_to_agent',
'arguments': json.dumps({
'assistant': event.new_agent.name
})
},
'id': tool_call_id,
'type': 'function'
}],
'tool_call_id': None,
'tool_name': None,
'response_type': 'internal'
}
print("Yielding message: ", message)
yield ('message', message)
# yield the transfer result
message = {
'content': json.dumps({
'assistant': event.new_agent.name
}),
'role': 'tool',
'sender': None,
'tool_calls': None,
'tool_call_id': tool_call_id,
'tool_name': 'transfer_to_agent',
}
print("Yielding message: ", message)
yield ('message', message)
current_agent = event.new_agent
continue
# Handle run items (tools, messages, etc)