mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-05-05 13:22:38 +02:00
Made agent run function async (#49)
* fixed sync run function * fixed simulation
This commit is contained in:
parent
a56b2d3a90
commit
18812a4887
4 changed files with 52 additions and 47 deletions
|
|
@ -81,7 +81,7 @@ def create_final_response(response, turn_messages, tokens_used, all_agents):
|
|||
return response.messages, response.tokens_used, new_state
|
||||
|
||||
|
||||
def run_turn(
|
||||
async def run_turn(
|
||||
messages, start_agent_name, agent_configs, tool_configs, start_turn_with_start_agent, state={}, additional_tool_configs=[], complete_request={}
|
||||
):
|
||||
"""
|
||||
|
|
@ -147,7 +147,7 @@ def run_turn(
|
|||
logger.info("Running swarm run")
|
||||
print("Running swarm run")
|
||||
|
||||
response = swarm_run(
|
||||
response = await swarm_run(
|
||||
agent=last_new_agent,
|
||||
messages=messages,
|
||||
external_tools=external_tools,
|
||||
|
|
@ -231,12 +231,12 @@ def run_turn(
|
|||
)
|
||||
|
||||
async def run_turn_streamed(
|
||||
messages,
|
||||
start_agent_name,
|
||||
agent_configs,
|
||||
tool_configs,
|
||||
start_turn_with_start_agent,
|
||||
state={},
|
||||
messages,
|
||||
start_agent_name,
|
||||
agent_configs,
|
||||
tool_configs,
|
||||
start_turn_with_start_agent,
|
||||
state={},
|
||||
additional_tool_configs=[],
|
||||
complete_request={}
|
||||
):
|
||||
|
|
@ -254,7 +254,7 @@ async def run_turn_streamed(
|
|||
)
|
||||
last_new_agent = get_agent_by_name(last_agent_name, new_agents)
|
||||
external_tools = get_external_tools(tool_configs)
|
||||
|
||||
|
||||
current_agent = last_new_agent
|
||||
tokens_used = {"total": 0, "prompt": 0, "completion": 0}
|
||||
|
||||
|
|
@ -270,7 +270,7 @@ async def run_turn_streamed(
|
|||
print('='*50)
|
||||
print("Received event: ", event)
|
||||
print('-'*50)
|
||||
|
||||
|
||||
# Handle raw response events and accumulate tokens
|
||||
if event.type == "raw_response_event":
|
||||
if hasattr(event.data, 'type') and event.data.type == "response.completed":
|
||||
|
|
@ -282,7 +282,7 @@ async def run_turn_streamed(
|
|||
print(f"Found usage information. Updated cumulative tokens: {tokens_used}")
|
||||
print('-'*50)
|
||||
continue
|
||||
|
||||
|
||||
# Update current agent when it changes
|
||||
elif event.type == "agent_updated_stream_event":
|
||||
current_agent = event.new_agent
|
||||
|
|
@ -323,10 +323,10 @@ async def run_turn_streamed(
|
|||
}
|
||||
print("Yielding message: ", message)
|
||||
yield ('message', message)
|
||||
|
||||
|
||||
current_agent = event.new_agent
|
||||
continue
|
||||
|
||||
|
||||
# Handle run items (tools, messages, etc)
|
||||
elif event.type == "run_item_stream_event":
|
||||
if event.item.type == "tool_call_item":
|
||||
|
|
@ -348,7 +348,7 @@ async def run_turn_streamed(
|
|||
}
|
||||
print("Yielding message: ", message)
|
||||
yield ('message', message)
|
||||
|
||||
|
||||
elif event.item.type == "tool_call_output_item":
|
||||
message = {
|
||||
'content': str(event.item.output),
|
||||
|
|
@ -361,7 +361,7 @@ async def run_turn_streamed(
|
|||
}
|
||||
print("Yielding message: ", message)
|
||||
yield ('message', message)
|
||||
|
||||
|
||||
elif event.item.type == "message_output_item":
|
||||
content = ""
|
||||
if hasattr(event.item.raw_item, 'content'):
|
||||
|
|
|
|||
|
|
@ -311,7 +311,7 @@ def create_response(messages=None, tokens_used=None, agent=None, error_msg=''):
|
|||
)
|
||||
|
||||
|
||||
def run(
|
||||
async def run(
|
||||
agent,
|
||||
messages,
|
||||
external_tools=None,
|
||||
|
|
@ -344,19 +344,11 @@ def run(
|
|||
"content": str(msg)
|
||||
})
|
||||
|
||||
# Create a new event loop for this thread
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Run the agent with the formatted messages
|
||||
logger.info("Beginning Swarm run with run_sync")
|
||||
print("Beginning Swarm run with run_sync")
|
||||
logger.info("Beginning Swarm run")
|
||||
print("Beginning Swarm run")
|
||||
|
||||
try:
|
||||
response = loop.run_until_complete(Runner.run(agent, formatted_messages))
|
||||
response = await Runner.run(agent, formatted_messages)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during run: {str(e)}")
|
||||
print(f"Error during run: {str(e)}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue