rowboat/apps/rowboat_agents/src/app/main.py

186 lines
6.4 KiB
Python
Raw Normal View History

2025-04-05 00:37:20 +05:30
import traceback
from quart import Quart, request, jsonify, Response
2025-03-25 15:37:51 +05:30
from functools import wraps
import os
import json
from hypercorn.config import Config
from hypercorn.asyncio import serve
import asyncio
2025-03-25 15:37:51 +05:30
from src.graph.core import run_turn_streamed
from src.utils.common import read_json_from_file
2025-03-25 15:37:51 +05:30
app = Quart(__name__)
master_config = read_json_from_file("./configs/default_config.json")
print("Master config:", master_config)

# Tracing is opt-in: it is enabled only when ENABLE_TRACING is set to "true"
# (case-insensitive, surrounding whitespace ignored). An unset variable simply
# means False; it is not treated as an error.
ENABLE_TRACING = os.environ.get('ENABLE_TRACING', '').strip().lower() == 'true'
# filter out agent transfer messages using a function
def is_agent_transfer_message(msg):
    """Return True if *msg* is one half of a ``transfer_to_agent`` handoff.

    Two message shapes are recognised:

    * an ``assistant`` message with ``content`` None whose first tool call
      invokes the ``transfer_to_agent`` function, and
    * a ``tool`` message with no ``tool_calls``, a non-None ``tool_call_id``
      and ``tool_name == "transfer_to_agent"``.

    A malformed tool call (null entry, or missing/null ``function``) is treated
    as "not a transfer" rather than raising AttributeError.
    """
    role = msg.get("role")
    tool_calls = msg.get("tool_calls")

    if role == "assistant" and msg.get("content") is None and tool_calls:
        first_call = tool_calls[0] or {}
        function = first_call.get("function") or {}
        return function.get("name") == "transfer_to_agent"

    return (
        role == "tool"
        and tool_calls is None
        and msg.get("tool_call_id") is not None
        and msg.get("tool_name") == "transfer_to_agent"
    )
@app.route("/health", methods=["GET"])
async def health():
    """Liveness probe: always responds with a static ``{"status": "ok"}`` body."""
    payload = {"status": "ok"}
    return jsonify(payload)
@app.route("/")
async def home():
    """Root endpoint; replies with a fixed greeting string."""
    greeting = "Hello, World!"
    return greeting
def require_api_key(f):
    """Wrap an async view so it requires an ``Authorization: Bearer <token>`` header.

    * Missing header, or one not starting with ``Bearer `` -> 401 JSON error.
    * If the ``API_KEY`` environment variable is set (non-blank after strip),
      the token must equal it, otherwise -> 403 JSON error.
    * If ``API_KEY`` is unset or blank, any bearer token is accepted.

    ``API_KEY`` is re-read on every request.
    """
    import hmac  # constant-time comparison for the secret check

    prefix = 'Bearer '

    @wraps(f)
    async def decorated(*args, **kwargs):
        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith(prefix):
            return jsonify({'error': 'Missing or invalid authorization header'}), 401
        # Slice the prefix off rather than split(): a token that itself
        # contains "Bearer " must not be truncated.
        token = auth_header[len(prefix):].strip()
        actual = os.environ.get('API_KEY', '').strip()
        # compare_digest does not leak, through response timing, how many
        # leading characters matched. Bytes are compared so non-ASCII input
        # cannot raise TypeError.
        if actual and not hmac.compare_digest(token.encode('utf-8'), actual.encode('utf-8')):
            return jsonify({'error': 'Invalid API key'}), 403
        return await f(*args, **kwargs)

    return decorated
@app.route("/chat", methods=["POST"])
@require_api_key
async def chat():
    """Run one agent turn and return every produced message in a single JSON reply.

    The request body must be a JSON object with a ``messages`` list. Optional
    keys ``startAgent``, ``agents``, ``tools``, ``prompts`` and ``state`` are
    forwarded to ``run_turn_streamed``.

    Returns:
        200 ``{"messages": [...], "state": ...}`` on success;
        400 ``{"error": ...}`` when the body is not a JSON object with a
        ``messages`` list; 500 ``{"error": ...}`` on any other failure.
    """
    print('='*100)
    print(f"{'*'*100}Running server mode{'*'*100}")

    try:
        # force=True parses the body as JSON regardless of Content-Type, the
        # same way /chat_stream json.loads its raw body. silent=True yields
        # None instead of raising on bad JSON, so it can be answered with 400.
        request_data = await request.get_json(force=True, silent=True)
        if not isinstance(request_data, dict) or not isinstance(request_data.get("messages"), list):
            return jsonify({"error": "Request body must be a JSON object with a 'messages' list"}), 400
        print("Request:", json.dumps(request_data))

        # filter out agent transfer messages
        input_messages = [msg for msg in request_data["messages"] if not is_agent_transfer_message(msg)]
        # Preprocess messages to handle null content and role issues. The
        # message dicts are edited in place, so request_data (passed below as
        # complete_request) reflects the same changes.
        for msg in input_messages:
            if (msg.get("role") == "assistant"
                    and msg.get("content") is None
                    and msg.get("tool_calls")):
                msg["content"] = "Calling tool"
            if msg.get("role") == "tool":
                msg["role"] = "developer"
            elif not msg.get("role"):
                msg["role"] = "user"

        messages = []
        final_state = {}
        async for event_type, event_data in run_turn_streamed(
            messages=input_messages,
            start_agent_name=request_data.get("startAgent", ""),
            agent_configs=request_data.get("agents", []),
            tool_configs=request_data.get("tools", []),
            prompt_configs=request_data.get("prompts", []),
            start_turn_with_start_agent=master_config.get("start_turn_with_start_agent", False),
            state=request_data.get("state", {}),
            complete_request=request_data,
            enable_tracing=ENABLE_TRACING,
        ):
            if event_type == 'message':
                messages.append(event_data)
            elif event_type == 'done':
                final_state = event_data['state']
            # NOTE(review): 'error' events are silently dropped here, whereas
            # /chat_stream forwards them; confirm whether /chat should surface them.

        out = {
            "messages": messages,
            "state": final_state,
        }
        print("Output:")
        for k, v in out.items():
            print(f"{k}: {v}")
        print('*'*100)
        return jsonify(out)
    except Exception as e:
        print(traceback.format_exc())
        print(f"Error: {str(e)}")
        return jsonify({"error": str(e)}), 500
def format_sse(data: dict, event: "str | None" = None) -> str:
    """Serialize *data* as one Server-Sent Events frame.

    The payload is JSON-encoded on a single ``data:`` line and the frame is
    terminated by a blank line. When *event* is given (even as an empty
    string), an ``event:`` line naming it is placed first.
    """
    frame_parts = []
    if event is not None:
        frame_parts.append(f"event: {event}\n")
    frame_parts.append(f"data: {json.dumps(data)}\n\n")
    return "".join(frame_parts)
@app.route("/chat_stream", methods=["POST"])
@require_api_key
async def chat_stream():
    """Run one agent turn and stream its events to the client as Server-Sent Events.

    Accepts the same JSON body as ``/chat``. Each ``message``, ``done`` or
    ``error`` event from ``run_turn_streamed`` is forwarded as an SSE frame
    with that event name. An exception raised mid-stream is reported as a
    final ``error`` frame. A body that is not a JSON object with a
    ``messages`` list is rejected up front with a 400 JSON response.
    """
    # get the raw request body; json.loads accepts bytes directly
    raw_body = await request.get_data()
    # errors='replace' keeps request logging from crashing on non-UTF-8 input
    print("Request:", raw_body.decode('utf-8', errors='replace'))
    try:
        request_data = json.loads(raw_body)
    except ValueError as e:  # covers json.JSONDecodeError and UnicodeDecodeError
        return jsonify({"error": f"Invalid JSON body: {e}"}), 400
    if not isinstance(request_data, dict) or not isinstance(request_data.get("messages"), list):
        return jsonify({"error": "Request body must be a JSON object with a 'messages' list"}), 400

    # filter out agent transfer messages
    input_messages = [msg for msg in request_data["messages"] if not is_agent_transfer_message(msg)]
    # Preprocess messages to handle null content and role issues (in place,
    # so complete_request below sees the same edits).
    for msg in input_messages:
        if (msg.get("role") == "assistant"
                and msg.get("content") is None
                and msg.get("tool_calls")):
            msg["content"] = "Calling tool"
        if msg.get("role") == "tool":
            msg["role"] = "developer"
        elif not msg.get("role"):
            msg["role"] = "user"

    async def generate():
        print("Running generate() in server")
        try:
            async for event_type, event_data in run_turn_streamed(
                messages=input_messages,
                start_agent_name=request_data.get("startAgent", ""),
                agent_configs=request_data.get("agents", []),
                tool_configs=request_data.get("tools", []),
                prompt_configs=request_data.get("prompts", []),
                start_turn_with_start_agent=master_config.get("start_turn_with_start_agent", False),
                state=request_data.get("state", {}),
                complete_request=request_data,
                enable_tracing=ENABLE_TRACING,
            ):
                # Forward known event types under their own name. The error
                # frame was previously named " error" (leading space), which
                # SSE clients listening for "error" never receive.
                if event_type in ('message', 'done', 'error'):
                    yield format_sse(event_data, event_type)
        except Exception as e:
            print(traceback.format_exc())
            print(f"Streaming error: {str(e)}")
            yield format_sse({"error": str(e)}, "error")

    return Response(generate(), mimetype='text/event-stream')
if __name__ == "__main__":
    # Script entry point: serve the Quart app with Hypercorn on all interfaces, port 4040.
    print("Starting async server...")
    hypercorn_config = Config()
    hypercorn_config.bind = ["0.0.0.0:4040"]
    server_coroutine = serve(app, hypercorn_config)
    asyncio.run(server_coroutine)