Next.js changes for playground streaming

This commit is contained in:
ramnique 2025-03-25 01:42:22 +05:30 committed by Ramnique Singh
parent 24efe0e887
commit 77b53696b6
14 changed files with 290 additions and 160 deletions

View file

@ -1,13 +1,13 @@
from flask import Flask, request, jsonify, Response
from quart import Quart, request, jsonify, Response
from datetime import datetime
from functools import wraps
import os
import redis
import uuid
import json
import asyncio
from hypercorn.config import Config
from hypercorn.asyncio import serve
import asyncio
from src.graph.core import run_turn, run_turn_streamed
from src.graph.tools import RAG_TOOL, CLOSE_CHAT_TOOL
@ -17,19 +17,34 @@ from pprint import pprint
logger = common_logger
redis_client = redis.from_url(os.environ.get('REDIS_URL', 'redis://localhost:6379'))
app = Flask(__name__)
app = Quart(__name__)
# filter out agent transfer messages using a function
def is_agent_transfer_message(msg):
if (msg.get("role") == "assistant" and
msg.get("content") is None and
msg.get("tool_calls") is not None and
len(msg.get("tool_calls")) > 0 and
msg.get("tool_calls")[0].get("function").get("name") == "transfer_to_agent"):
return True
if (msg.get("role") == "tool" and
msg.get("tool_calls") is None and
msg.get("tool_call_id") is not None and
msg.get("tool_name") == "transfer_to_agent"):
return True
return False
@app.route("/health", methods=["GET"])
def health():
async def health():
return jsonify({"status": "ok"})
@app.route("/")
def home():
async def home():
return "Hello, World!"
def require_api_key(f):
@wraps(f)
def decorated(*args, **kwargs):
async def decorated(*args, **kwargs):
auth_header = request.headers.get('Authorization')
if not auth_header or not auth_header.startswith('Bearer '):
return jsonify({'error': 'Missing or invalid authorization header'}), 401
@ -39,16 +54,16 @@ def require_api_key(f):
if actual and token != actual:
return jsonify({'error': 'Invalid API key'}), 403
return f(*args, **kwargs)
return await f(*args, **kwargs)
return decorated
@app.route("/chat", methods=["POST"])
@require_api_key
def chat():
async def chat():
logger.info('='*100)
logger.info(f"{'*'*100}Running server mode{'*'*100}")
try:
data = request.get_json()
data = await request.get_json()
logger.info('Complete request:')
logger.info(data)
logger.info('-'*100)
@ -56,9 +71,12 @@ def chat():
start_time = datetime.now()
config = read_json_from_file("./configs/default_config.json")
# filter out agent transfer messages
input_messages = [msg for msg in data.get("messages", []) if not is_agent_transfer_message(msg)]
logger.info('Beginning turn')
resp_messages, resp_tokens_used, resp_state = run_turn(
messages=data.get("messages", []),
messages=input_messages,
start_agent_name=data.get("startAgent", ""),
agent_configs=data.get("agents", []),
tool_configs=data.get("tools", []),
@ -94,19 +112,27 @@ def chat():
@app.route("/chat_stream_init", methods=["POST"])
@require_api_key
def chat_stream_init():
async def chat_stream_init():
# create a uuid for the stream
stream_id = str(uuid.uuid4())
# store the request data in redis with 10 minute TTL
data = request.get_json()
data = await request.get_json()
redis_client.setex(f"stream_request_{stream_id}", 600, json.dumps(data))
return jsonify({"stream_id": stream_id})
print('* stream init'*200)
return jsonify({"streamId": stream_id})
def format_sse(data: dict, event: str = None) -> str:
msg = f"data: {json.dumps(data)}\n\n"
if event is not None:
msg = f"event: {event}\n{msg}"
return msg
@app.route("/chat_stream/<stream_id>", methods=["GET"])
@require_api_key
def chat_stream(stream_id):
async def chat_stream(stream_id):
# get the request data from redis
request_data = redis_client.get(f"stream_request_{stream_id}")
if not request_data:
@ -114,17 +140,18 @@ def chat_stream(stream_id):
request_data = json.loads(request_data)
config = read_json_from_file("./configs/default_config.json")
# filter out agent transfer messages
input_messages = [msg for msg in request_data["messages"] if not is_agent_transfer_message(msg)]
# Preprocess messages to handle null content and role issues
for msg in request_data["messages"]:
# Handle null content in assistant messages with tool calls
for msg in input_messages:
if (msg.get("role") == "assistant" and
msg.get("content") is None and
msg.get("tool_calls") is not None and
len(msg.get("tool_calls")) > 0):
msg["content"] = "Calling tool"
# Handle role issues
if msg.get("role") == "tool":
msg["role"] = "developer"
elif not msg.get("role"):
@ -135,12 +162,11 @@ def chat_stream(stream_id):
print('*'*200)
pprint(request_data)
print('='*200)
async def process_stream():
async def generate():
try:
async for event_type, event_data in run_turn_streamed(
messages=request_data.get("messages", []),
messages=input_messages,
start_agent_name=request_data.get("startAgent", ""),
agent_configs=request_data.get("agents", []),
tool_configs=request_data.get("tools", []),
@ -153,43 +179,16 @@ def chat_stream(stream_id):
print('*'*200)
print("Yielding message:")
print('*'*200)
to_yield = f"event: message\ndata: {json.dumps(event_data)}\n\n"
print(to_yield)
print('='*200)
yield to_yield
yield format_sse(event_data, "message")
elif event_type == 'done':
print('*'*200)
print("Yielding done:")
print('*'*200)
to_yield = f"event: done\ndata: {json.dumps(event_data)}\n\n"
print(to_yield)
print('='*200)
yield to_yield
yield format_sse(event_data, "done")
except Exception as e:
logger.error(f"Streaming error: {str(e)}")
yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
def generate():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
async def get_all_chunks():
chunks = []
async for chunk in process_stream():
chunks.append(chunk)
return chunks
chunks = loop.run_until_complete(get_all_chunks())
for chunk in chunks:
yield chunk
except Exception as e:
logger.error(f"Error in generate: {e}")
raise
finally:
loop.close()
yield format_sse({"error": str(e)}, "error")
return Response(generate(), mimetype='text/event-stream')

View file

@ -1,6 +1,7 @@
from copy import deepcopy
from datetime import datetime
import json
import uuid
import logging
from .helpers.access import (
get_agent_by_name,
@ -285,16 +286,45 @@ async def run_turn_streamed(
# Update current agent when it changes
elif event.type == "agent_updated_stream_event":
current_agent = event.new_agent
tool_call_id = str(uuid.uuid4())
# yield the transfer invocation
message = {
'content': f"Agent changed to {current_agent.name}",
'content': None,
'role': 'assistant',
'sender': current_agent.name,
'tool_calls': None,
'tool_calls': [{
'function': {
'name': 'transfer_to_agent',
'arguments': json.dumps({
'assistant': event.new_agent.name
})
},
'id': tool_call_id,
'type': 'function'
}],
'tool_call_id': None,
'tool_name': None,
'response_type': 'internal'
}
print("Yielding message: ", message)
yield ('message', message)
# yield the transfer result
message = {
'content': json.dumps({
'assistant': event.new_agent.name
}),
'role': 'tool',
'sender': None,
'tool_calls': None,
'tool_call_id': tool_call_id,
'tool_name': 'transfer_to_agent',
}
print("Yielding message: ", message)
yield ('message', message)
current_agent = event.new_agent
continue
# Handle run items (tools, messages, etc)