mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-04-26 17:06:23 +02:00
Next.js changes for playground streaming
This commit is contained in:
parent
24efe0e887
commit
77b53696b6
14 changed files with 290 additions and 160 deletions
|
|
@ -1,13 +1,13 @@
|
|||
from flask import Flask, request, jsonify, Response
|
||||
from quart import Quart, request, jsonify, Response
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
import os
|
||||
import redis
|
||||
import uuid
|
||||
import json
|
||||
import asyncio
|
||||
from hypercorn.config import Config
|
||||
from hypercorn.asyncio import serve
|
||||
import asyncio
|
||||
|
||||
from src.graph.core import run_turn, run_turn_streamed
|
||||
from src.graph.tools import RAG_TOOL, CLOSE_CHAT_TOOL
|
||||
|
|
@ -17,19 +17,34 @@ from pprint import pprint
|
|||
|
||||
logger = common_logger
|
||||
redis_client = redis.from_url(os.environ.get('REDIS_URL', 'redis://localhost:6379'))
|
||||
app = Flask(__name__)
|
||||
app = Quart(__name__)
|
||||
|
||||
# filter out agent transfer messages using a function
|
||||
def is_agent_transfer_message(msg):
|
||||
if (msg.get("role") == "assistant" and
|
||||
msg.get("content") is None and
|
||||
msg.get("tool_calls") is not None and
|
||||
len(msg.get("tool_calls")) > 0 and
|
||||
msg.get("tool_calls")[0].get("function").get("name") == "transfer_to_agent"):
|
||||
return True
|
||||
if (msg.get("role") == "tool" and
|
||||
msg.get("tool_calls") is None and
|
||||
msg.get("tool_call_id") is not None and
|
||||
msg.get("tool_name") == "transfer_to_agent"):
|
||||
return True
|
||||
return False
|
||||
|
||||
@app.route("/health", methods=["GET"])
|
||||
def health():
|
||||
async def health():
|
||||
return jsonify({"status": "ok"})
|
||||
|
||||
@app.route("/")
|
||||
def home():
|
||||
async def home():
|
||||
return "Hello, World!"
|
||||
|
||||
def require_api_key(f):
|
||||
@wraps(f)
|
||||
def decorated(*args, **kwargs):
|
||||
async def decorated(*args, **kwargs):
|
||||
auth_header = request.headers.get('Authorization')
|
||||
if not auth_header or not auth_header.startswith('Bearer '):
|
||||
return jsonify({'error': 'Missing or invalid authorization header'}), 401
|
||||
|
|
@ -39,16 +54,16 @@ def require_api_key(f):
|
|||
if actual and token != actual:
|
||||
return jsonify({'error': 'Invalid API key'}), 403
|
||||
|
||||
return f(*args, **kwargs)
|
||||
return await f(*args, **kwargs)
|
||||
return decorated
|
||||
|
||||
@app.route("/chat", methods=["POST"])
|
||||
@require_api_key
|
||||
def chat():
|
||||
async def chat():
|
||||
logger.info('='*100)
|
||||
logger.info(f"{'*'*100}Running server mode{'*'*100}")
|
||||
try:
|
||||
data = request.get_json()
|
||||
data = await request.get_json()
|
||||
logger.info('Complete request:')
|
||||
logger.info(data)
|
||||
logger.info('-'*100)
|
||||
|
|
@ -56,9 +71,12 @@ def chat():
|
|||
start_time = datetime.now()
|
||||
config = read_json_from_file("./configs/default_config.json")
|
||||
|
||||
# filter out agent transfer messages
|
||||
input_messages = [msg for msg in data.get("messages", []) if not is_agent_transfer_message(msg)]
|
||||
|
||||
logger.info('Beginning turn')
|
||||
resp_messages, resp_tokens_used, resp_state = run_turn(
|
||||
messages=data.get("messages", []),
|
||||
messages=input_messages,
|
||||
start_agent_name=data.get("startAgent", ""),
|
||||
agent_configs=data.get("agents", []),
|
||||
tool_configs=data.get("tools", []),
|
||||
|
|
@ -94,19 +112,27 @@ def chat():
|
|||
|
||||
@app.route("/chat_stream_init", methods=["POST"])
|
||||
@require_api_key
|
||||
def chat_stream_init():
|
||||
async def chat_stream_init():
|
||||
# create a uuid for the stream
|
||||
stream_id = str(uuid.uuid4())
|
||||
|
||||
# store the request data in redis with 10 minute TTL
|
||||
data = request.get_json()
|
||||
data = await request.get_json()
|
||||
redis_client.setex(f"stream_request_{stream_id}", 600, json.dumps(data))
|
||||
|
||||
return jsonify({"stream_id": stream_id})
|
||||
print('* stream init'*200)
|
||||
|
||||
return jsonify({"streamId": stream_id})
|
||||
|
||||
def format_sse(data: dict, event: str = None) -> str:
|
||||
msg = f"data: {json.dumps(data)}\n\n"
|
||||
if event is not None:
|
||||
msg = f"event: {event}\n{msg}"
|
||||
return msg
|
||||
|
||||
@app.route("/chat_stream/<stream_id>", methods=["GET"])
|
||||
@require_api_key
|
||||
def chat_stream(stream_id):
|
||||
async def chat_stream(stream_id):
|
||||
# get the request data from redis
|
||||
request_data = redis_client.get(f"stream_request_{stream_id}")
|
||||
if not request_data:
|
||||
|
|
@ -114,17 +140,18 @@ def chat_stream(stream_id):
|
|||
|
||||
request_data = json.loads(request_data)
|
||||
config = read_json_from_file("./configs/default_config.json")
|
||||
|
||||
# filter out agent transfer messages
|
||||
input_messages = [msg for msg in request_data["messages"] if not is_agent_transfer_message(msg)]
|
||||
|
||||
# Preprocess messages to handle null content and role issues
|
||||
for msg in request_data["messages"]:
|
||||
# Handle null content in assistant messages with tool calls
|
||||
for msg in input_messages:
|
||||
if (msg.get("role") == "assistant" and
|
||||
msg.get("content") is None and
|
||||
msg.get("tool_calls") is not None and
|
||||
len(msg.get("tool_calls")) > 0):
|
||||
msg["content"] = "Calling tool"
|
||||
|
||||
# Handle role issues
|
||||
if msg.get("role") == "tool":
|
||||
msg["role"] = "developer"
|
||||
elif not msg.get("role"):
|
||||
|
|
@ -135,12 +162,11 @@ def chat_stream(stream_id):
|
|||
print('*'*200)
|
||||
pprint(request_data)
|
||||
print('='*200)
|
||||
|
||||
|
||||
async def process_stream():
|
||||
async def generate():
|
||||
try:
|
||||
async for event_type, event_data in run_turn_streamed(
|
||||
messages=request_data.get("messages", []),
|
||||
messages=input_messages,
|
||||
start_agent_name=request_data.get("startAgent", ""),
|
||||
agent_configs=request_data.get("agents", []),
|
||||
tool_configs=request_data.get("tools", []),
|
||||
|
|
@ -153,43 +179,16 @@ def chat_stream(stream_id):
|
|||
print('*'*200)
|
||||
print("Yielding message:")
|
||||
print('*'*200)
|
||||
to_yield = f"event: message\ndata: {json.dumps(event_data)}\n\n"
|
||||
print(to_yield)
|
||||
print('='*200)
|
||||
yield to_yield
|
||||
yield format_sse(event_data, "message")
|
||||
elif event_type == 'done':
|
||||
print('*'*200)
|
||||
print("Yielding done:")
|
||||
print('*'*200)
|
||||
to_yield = f"event: done\ndata: {json.dumps(event_data)}\n\n"
|
||||
print(to_yield)
|
||||
print('='*200)
|
||||
yield to_yield
|
||||
yield format_sse(event_data, "done")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming error: {str(e)}")
|
||||
yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
|
||||
|
||||
def generate():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
async def get_all_chunks():
|
||||
chunks = []
|
||||
async for chunk in process_stream():
|
||||
chunks.append(chunk)
|
||||
return chunks
|
||||
|
||||
chunks = loop.run_until_complete(get_all_chunks())
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in generate: {e}")
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
yield format_sse({"error": str(e)}, "error")
|
||||
|
||||
return Response(generate(), mimetype='text/event-stream')
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
|
||||
import json
|
||||
import uuid
|
||||
import logging
|
||||
from .helpers.access import (
|
||||
get_agent_by_name,
|
||||
|
|
@ -285,16 +286,45 @@ async def run_turn_streamed(
|
|||
# Update current agent when it changes
|
||||
elif event.type == "agent_updated_stream_event":
|
||||
current_agent = event.new_agent
|
||||
tool_call_id = str(uuid.uuid4())
|
||||
|
||||
# yield the transfer invocation
|
||||
message = {
|
||||
'content': f"Agent changed to {current_agent.name}",
|
||||
'content': None,
|
||||
'role': 'assistant',
|
||||
'sender': current_agent.name,
|
||||
'tool_calls': None,
|
||||
'tool_calls': [{
|
||||
'function': {
|
||||
'name': 'transfer_to_agent',
|
||||
'arguments': json.dumps({
|
||||
'assistant': event.new_agent.name
|
||||
})
|
||||
},
|
||||
'id': tool_call_id,
|
||||
'type': 'function'
|
||||
}],
|
||||
'tool_call_id': None,
|
||||
'tool_name': None,
|
||||
'response_type': 'internal'
|
||||
}
|
||||
print("Yielding message: ", message)
|
||||
yield ('message', message)
|
||||
|
||||
# yield the transfer result
|
||||
message = {
|
||||
'content': json.dumps({
|
||||
'assistant': event.new_agent.name
|
||||
}),
|
||||
'role': 'tool',
|
||||
'sender': None,
|
||||
'tool_calls': None,
|
||||
'tool_call_id': tool_call_id,
|
||||
'tool_name': 'transfer_to_agent',
|
||||
}
|
||||
print("Yielding message: ", message)
|
||||
yield ('message', message)
|
||||
|
||||
current_agent = event.new_agent
|
||||
continue
|
||||
|
||||
# Handle run items (tools, messages, etc)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue