mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-05-01 11:26:23 +02:00
Enable internal and user-facing agents to build pipelines
This commit is contained in:
parent
1246ea47b9
commit
e59a8b75cf
24 changed files with 2100 additions and 1376 deletions
|
|
@ -10,11 +10,11 @@ import asyncio
|
|||
|
||||
from src.graph.core import run_turn_streamed
|
||||
from src.graph.tools import RAG_TOOL, CLOSE_CHAT_TOOL
|
||||
from src.utils.common import common_logger, read_json_from_file
|
||||
from src.utils.common import read_json_from_file
|
||||
|
||||
logger = common_logger
|
||||
app = Quart(__name__)
|
||||
config = read_json_from_file("./configs/default_config.json")
|
||||
master_config = read_json_from_file("./configs/default_config.json")
|
||||
print("Master config:", master_config)
|
||||
|
||||
# filter out agent transfer messages using a function
|
||||
def is_agent_transfer_message(msg):
|
||||
|
|
@ -57,12 +57,15 @@ def require_api_key(f):
|
|||
@app.route("/chat", methods=["POST"])
|
||||
@require_api_key
|
||||
async def chat():
|
||||
logger.info('='*100)
|
||||
logger.info(f"{'*'*100}Running server mode{'*'*100}")
|
||||
print('='*100)
|
||||
print(f"{'*'*100}Running server mode{'*'*100}")
|
||||
try:
|
||||
request_data = await request.get_json()
|
||||
print("Request:", json.dumps(request_data))
|
||||
|
||||
# Add enable_tracing from master_config to request_data
|
||||
request_data["enable_tracing"] = master_config.get("enable_tracing", False)
|
||||
|
||||
# filter out agent transfer messages
|
||||
input_messages = [msg for msg in request_data["messages"] if not is_agent_transfer_message(msg)]
|
||||
|
||||
|
|
@ -82,7 +85,6 @@ async def chat():
|
|||
data = request_data
|
||||
messages = []
|
||||
final_state = {}
|
||||
# tokens_used = 0
|
||||
|
||||
async for event_type, event_data in run_turn_streamed(
|
||||
messages=input_messages,
|
||||
|
|
@ -90,7 +92,8 @@ async def chat():
|
|||
agent_configs=data.get("agents", []),
|
||||
tool_configs=data.get("tools", []),
|
||||
prompt_configs=data.get("prompts", []),
|
||||
start_turn_with_start_agent=config.get("start_turn_with_start_agent", False),
|
||||
start_turn_with_start_agent=master_config.get("start_turn_with_start_agent", False),
|
||||
max_calls_per_child_agent=master_config.get("max_calls_per_child_agent", 1),
|
||||
state=data.get("state", {}),
|
||||
additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL],
|
||||
complete_request=data
|
||||
|
|
@ -99,23 +102,22 @@ async def chat():
|
|||
messages.append(event_data)
|
||||
elif event_type == 'done':
|
||||
final_state = event_data['state']
|
||||
# tokens_used = event_data["tokens_used"]
|
||||
|
||||
out = {
|
||||
"messages": messages,
|
||||
"state": final_state,
|
||||
}
|
||||
|
||||
logger.info("Output:")
|
||||
print("Output:")
|
||||
for k, v in out.items():
|
||||
logger.info(f"{k}: {v}")
|
||||
logger.info('*'*100)
|
||||
print(f"{k}: {v}")
|
||||
print('*'*100)
|
||||
|
||||
return jsonify(out)
|
||||
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error(f"Error: {str(e)}")
|
||||
print(f"Error: {str(e)}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
def format_sse(data: dict, event: str = None) -> str:
|
||||
|
|
@ -133,6 +135,9 @@ async def chat_stream():
|
|||
print("Request:", request_data.decode('utf-8'))
|
||||
request_data = json.loads(request_data)
|
||||
|
||||
# Add enable_tracing from master_config to request_data
|
||||
request_data["enable_tracing"] = master_config.get("enable_tracing", False)
|
||||
|
||||
# filter out agent transfer messages
|
||||
input_messages = [msg for msg in request_data["messages"] if not is_agent_transfer_message(msg)]
|
||||
|
||||
|
|
@ -150,6 +155,7 @@ async def chat_stream():
|
|||
msg["role"] = "user"
|
||||
|
||||
async def generate():
|
||||
print("Running generate() in server")
|
||||
try:
|
||||
async for event_type, event_data in run_turn_streamed(
|
||||
messages=input_messages,
|
||||
|
|
@ -157,23 +163,21 @@ async def chat_stream():
|
|||
agent_configs=request_data.get("agents", []),
|
||||
tool_configs=request_data.get("tools", []),
|
||||
prompt_configs=request_data.get("prompts", []),
|
||||
start_turn_with_start_agent=config.get("start_turn_with_start_agent", False),
|
||||
start_turn_with_start_agent=master_config.get("start_turn_with_start_agent", False),
|
||||
max_calls_per_child_agent=master_config.get("max_calls_per_child_agent", 1),
|
||||
state=request_data.get("state", {}),
|
||||
additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL],
|
||||
complete_request=request_data
|
||||
):
|
||||
if event_type == 'message':
|
||||
print("Yielding message:")
|
||||
yield format_sse(event_data, "message")
|
||||
elif event_type == 'done':
|
||||
print("Yielding done:")
|
||||
yield format_sse(event_data, "done")
|
||||
elif event_type == 'error':
|
||||
print("Yielding error:")
|
||||
yield format_sse(event_data, "stream_error")
|
||||
yield format_sse(event_data, " error")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming error: {str(e)}")
|
||||
print(f"Streaming error: {str(e)}")
|
||||
yield format_sse({"error": str(e)}, "error")
|
||||
|
||||
return Response(generate(), mimetype='text/event-stream')
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue