mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-06-06 19:35:44 +02:00
Made agent run function async (#49)
* fixed sync run function * fixed simulation
This commit is contained in:
parent
a56b2d3a90
commit
18812a4887
4 changed files with 52 additions and 47 deletions
|
|
@ -63,19 +63,30 @@ async def chat():
|
|||
logger.info('='*100)
|
||||
logger.info(f"{'*'*100}Running server mode{'*'*100}")
|
||||
try:
|
||||
data = await request.get_json()
|
||||
logger.info('Complete request:')
|
||||
logger.info(data)
|
||||
logger.info('-'*100)
|
||||
|
||||
start_time = datetime.now()
|
||||
request_data = await request.get_json()
|
||||
config = read_json_from_file("./configs/default_config.json")
|
||||
|
||||
# filter out agent transfer messages
|
||||
input_messages = [msg for msg in data.get("messages", []) if not is_agent_transfer_message(msg)]
|
||||
input_messages = [msg for msg in request_data["messages"] if not is_agent_transfer_message(msg)]
|
||||
|
||||
logger.info('Beginning turn')
|
||||
resp_messages, resp_tokens_used, resp_state = run_turn(
|
||||
# Preprocess messages to handle null content and role issues
|
||||
for msg in input_messages:
|
||||
if (msg.get("role") == "assistant" and
|
||||
msg.get("content") is None and
|
||||
msg.get("tool_calls") is not None and
|
||||
len(msg.get("tool_calls")) > 0):
|
||||
msg["content"] = "Calling tool"
|
||||
|
||||
if msg.get("role") == "tool":
|
||||
msg["role"] = "developer"
|
||||
elif not msg.get("role"):
|
||||
msg["role"] = "user"
|
||||
|
||||
print("Request:")
|
||||
pprint(request_data)
|
||||
|
||||
data = request_data
|
||||
resp_messages, resp_tokens_used, resp_state = await run_turn(
|
||||
messages=input_messages,
|
||||
start_agent_name=data.get("startAgent", ""),
|
||||
agent_configs=data.get("agents", []),
|
||||
|
|
@ -101,9 +112,6 @@ async def chat():
|
|||
logger.info(f"{k}: {v}")
|
||||
logger.info('*'*100)
|
||||
|
||||
logger.info('='*100)
|
||||
logger.info(f"Processing time: {datetime.now() - start_time}")
|
||||
|
||||
return jsonify(out)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -138,18 +146,18 @@ async def chat_stream(stream_id):
|
|||
|
||||
request_data = json.loads(request_data)
|
||||
config = read_json_from_file("./configs/default_config.json")
|
||||
|
||||
|
||||
# filter out agent transfer messages
|
||||
input_messages = [msg for msg in request_data["messages"] if not is_agent_transfer_message(msg)]
|
||||
|
||||
# Preprocess messages to handle null content and role issues
|
||||
for msg in input_messages:
|
||||
if (msg.get("role") == "assistant" and
|
||||
msg.get("content") is None and
|
||||
msg.get("tool_calls") is not None and
|
||||
if (msg.get("role") == "assistant" and
|
||||
msg.get("content") is None and
|
||||
msg.get("tool_calls") is not None and
|
||||
len(msg.get("tool_calls")) > 0):
|
||||
msg["content"] = "Calling tool"
|
||||
|
||||
|
||||
if msg.get("role") == "tool":
|
||||
msg["role"] = "developer"
|
||||
elif not msg.get("role"):
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ def create_final_response(response, turn_messages, tokens_used, all_agents):
|
|||
return response.messages, response.tokens_used, new_state
|
||||
|
||||
|
||||
def run_turn(
|
||||
async def run_turn(
|
||||
messages, start_agent_name, agent_configs, tool_configs, start_turn_with_start_agent, state={}, additional_tool_configs=[], complete_request={}
|
||||
):
|
||||
"""
|
||||
|
|
@ -147,7 +147,7 @@ def run_turn(
|
|||
logger.info("Running swarm run")
|
||||
print("Running swarm run")
|
||||
|
||||
response = swarm_run(
|
||||
response = await swarm_run(
|
||||
agent=last_new_agent,
|
||||
messages=messages,
|
||||
external_tools=external_tools,
|
||||
|
|
@ -231,12 +231,12 @@ def run_turn(
|
|||
)
|
||||
|
||||
async def run_turn_streamed(
|
||||
messages,
|
||||
start_agent_name,
|
||||
agent_configs,
|
||||
tool_configs,
|
||||
start_turn_with_start_agent,
|
||||
state={},
|
||||
messages,
|
||||
start_agent_name,
|
||||
agent_configs,
|
||||
tool_configs,
|
||||
start_turn_with_start_agent,
|
||||
state={},
|
||||
additional_tool_configs=[],
|
||||
complete_request={}
|
||||
):
|
||||
|
|
@ -254,7 +254,7 @@ async def run_turn_streamed(
|
|||
)
|
||||
last_new_agent = get_agent_by_name(last_agent_name, new_agents)
|
||||
external_tools = get_external_tools(tool_configs)
|
||||
|
||||
|
||||
current_agent = last_new_agent
|
||||
tokens_used = {"total": 0, "prompt": 0, "completion": 0}
|
||||
|
||||
|
|
@ -270,7 +270,7 @@ async def run_turn_streamed(
|
|||
print('='*50)
|
||||
print("Received event: ", event)
|
||||
print('-'*50)
|
||||
|
||||
|
||||
# Handle raw response events and accumulate tokens
|
||||
if event.type == "raw_response_event":
|
||||
if hasattr(event.data, 'type') and event.data.type == "response.completed":
|
||||
|
|
@ -282,7 +282,7 @@ async def run_turn_streamed(
|
|||
print(f"Found usage information. Updated cumulative tokens: {tokens_used}")
|
||||
print('-'*50)
|
||||
continue
|
||||
|
||||
|
||||
# Update current agent when it changes
|
||||
elif event.type == "agent_updated_stream_event":
|
||||
current_agent = event.new_agent
|
||||
|
|
@ -323,10 +323,10 @@ async def run_turn_streamed(
|
|||
}
|
||||
print("Yielding message: ", message)
|
||||
yield ('message', message)
|
||||
|
||||
|
||||
current_agent = event.new_agent
|
||||
continue
|
||||
|
||||
|
||||
# Handle run items (tools, messages, etc)
|
||||
elif event.type == "run_item_stream_event":
|
||||
if event.item.type == "tool_call_item":
|
||||
|
|
@ -348,7 +348,7 @@ async def run_turn_streamed(
|
|||
}
|
||||
print("Yielding message: ", message)
|
||||
yield ('message', message)
|
||||
|
||||
|
||||
elif event.item.type == "tool_call_output_item":
|
||||
message = {
|
||||
'content': str(event.item.output),
|
||||
|
|
@ -361,7 +361,7 @@ async def run_turn_streamed(
|
|||
}
|
||||
print("Yielding message: ", message)
|
||||
yield ('message', message)
|
||||
|
||||
|
||||
elif event.item.type == "message_output_item":
|
||||
content = ""
|
||||
if hasattr(event.item.raw_item, 'content'):
|
||||
|
|
|
|||
|
|
@ -311,7 +311,7 @@ def create_response(messages=None, tokens_used=None, agent=None, error_msg=''):
|
|||
)
|
||||
|
||||
|
||||
def run(
|
||||
async def run(
|
||||
agent,
|
||||
messages,
|
||||
external_tools=None,
|
||||
|
|
@ -344,19 +344,11 @@ def run(
|
|||
"content": str(msg)
|
||||
})
|
||||
|
||||
# Create a new event loop for this thread
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Run the agent with the formatted messages
|
||||
logger.info("Beginning Swarm run with run_sync")
|
||||
print("Beginning Swarm run with run_sync")
|
||||
logger.info("Beginning Swarm run")
|
||||
print("Beginning Swarm run")
|
||||
|
||||
try:
|
||||
response = loop.run_until_complete(Runner.run(agent, formatted_messages))
|
||||
response = await Runner.run(agent, formatted_messages)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during run: {str(e)}")
|
||||
print(f"Error during run: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ async def simulate_simulation(
|
|||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
f"You are a customer talking to a chatbot. Have the following chat with the chatbot. Scenario:\n{scenario.description}. You are provided no other information. If the chatbot asks you for information that is not in context, go ahead and provide one unless stated otherwise in the scenario. Directly have the chat with the chatbot. Start now."
|
||||
f"You are role playing a customer talking to a chatbot (the user is role playing the chatbot). Have the following chat with the chatbot. Scenario:\n{scenario.description}. You are provided no other information. If the chatbot asks you for information that is not in context, go ahead and provide one unless stated otherwise in the scenario. Directly have the chat with the chatbot. Start now with your first message."
|
||||
)
|
||||
}
|
||||
]
|
||||
|
|
@ -64,20 +64,25 @@ async def simulate_simulation(
|
|||
)
|
||||
|
||||
simulated_content = simulated_user_response.choices[0].message.content.strip()
|
||||
messages.append({"role": "user", "content": simulated_content})
|
||||
messages.append({"role": "assistant", "content": simulated_content})
|
||||
# Run Rowboat chat in a thread if it's synchronous
|
||||
rowboat_response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: support_chat.run(simulated_content)
|
||||
)
|
||||
|
||||
messages.append({"role": "assistant", "content": rowboat_response})
|
||||
messages.append({"role": "user", "content": rowboat_response})
|
||||
|
||||
# -------------------------
|
||||
# (2) EVALUATION STEP
|
||||
# -------------------------
|
||||
# swap the roles of the assistant and the user
|
||||
transcript_str = ""
|
||||
for m in messages:
|
||||
if m.get("role") == "assistant":
|
||||
m["role"] = "user"
|
||||
elif m.get("role") == "user":
|
||||
m["role"] = "assistant"
|
||||
role = m.get("role", "unknown")
|
||||
content = m.get("content", "")
|
||||
transcript_str += f"{role.upper()}: {content}\n"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue