Made agent run function async (#49)

* fixed sync run function

* fixed simulation
This commit is contained in:
arkml 2025-03-28 00:51:53 +05:30 committed by GitHub
parent a56b2d3a90
commit 18812a4887
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 52 additions and 47 deletions

View file

@ -63,19 +63,30 @@ async def chat():
logger.info('='*100)
logger.info(f"{'*'*100}Running server mode{'*'*100}")
try:
data = await request.get_json()
logger.info('Complete request:')
logger.info(data)
logger.info('-'*100)
start_time = datetime.now()
request_data = await request.get_json()
config = read_json_from_file("./configs/default_config.json")
# filter out agent transfer messages
input_messages = [msg for msg in data.get("messages", []) if not is_agent_transfer_message(msg)]
input_messages = [msg for msg in request_data["messages"] if not is_agent_transfer_message(msg)]
logger.info('Beginning turn')
resp_messages, resp_tokens_used, resp_state = run_turn(
# Preprocess messages to handle null content and role issues
for msg in input_messages:
if (msg.get("role") == "assistant" and
msg.get("content") is None and
msg.get("tool_calls") is not None and
len(msg.get("tool_calls")) > 0):
msg["content"] = "Calling tool"
if msg.get("role") == "tool":
msg["role"] = "developer"
elif not msg.get("role"):
msg["role"] = "user"
print("Request:")
pprint(request_data)
data = request_data
resp_messages, resp_tokens_used, resp_state = await run_turn(
messages=input_messages,
start_agent_name=data.get("startAgent", ""),
agent_configs=data.get("agents", []),
@ -101,9 +112,6 @@ async def chat():
logger.info(f"{k}: {v}")
logger.info('*'*100)
logger.info('='*100)
logger.info(f"Processing time: {datetime.now() - start_time}")
return jsonify(out)
except Exception as e:
@ -138,18 +146,18 @@ async def chat_stream(stream_id):
request_data = json.loads(request_data)
config = read_json_from_file("./configs/default_config.json")
# filter out agent transfer messages
input_messages = [msg for msg in request_data["messages"] if not is_agent_transfer_message(msg)]
# Preprocess messages to handle null content and role issues
for msg in input_messages:
if (msg.get("role") == "assistant" and
msg.get("content") is None and
msg.get("tool_calls") is not None and
if (msg.get("role") == "assistant" and
msg.get("content") is None and
msg.get("tool_calls") is not None and
len(msg.get("tool_calls")) > 0):
msg["content"] = "Calling tool"
if msg.get("role") == "tool":
msg["role"] = "developer"
elif not msg.get("role"):

View file

@ -81,7 +81,7 @@ def create_final_response(response, turn_messages, tokens_used, all_agents):
return response.messages, response.tokens_used, new_state
def run_turn(
async def run_turn(
messages, start_agent_name, agent_configs, tool_configs, start_turn_with_start_agent, state={}, additional_tool_configs=[], complete_request={}
):
"""
@ -147,7 +147,7 @@ def run_turn(
logger.info("Running swarm run")
print("Running swarm run")
response = swarm_run(
response = await swarm_run(
agent=last_new_agent,
messages=messages,
external_tools=external_tools,
@ -231,12 +231,12 @@ def run_turn(
)
async def run_turn_streamed(
messages,
start_agent_name,
agent_configs,
tool_configs,
start_turn_with_start_agent,
state={},
messages,
start_agent_name,
agent_configs,
tool_configs,
start_turn_with_start_agent,
state={},
additional_tool_configs=[],
complete_request={}
):
@ -254,7 +254,7 @@ async def run_turn_streamed(
)
last_new_agent = get_agent_by_name(last_agent_name, new_agents)
external_tools = get_external_tools(tool_configs)
current_agent = last_new_agent
tokens_used = {"total": 0, "prompt": 0, "completion": 0}
@ -270,7 +270,7 @@ async def run_turn_streamed(
print('='*50)
print("Received event: ", event)
print('-'*50)
# Handle raw response events and accumulate tokens
if event.type == "raw_response_event":
if hasattr(event.data, 'type') and event.data.type == "response.completed":
@ -282,7 +282,7 @@ async def run_turn_streamed(
print(f"Found usage information. Updated cumulative tokens: {tokens_used}")
print('-'*50)
continue
# Update current agent when it changes
elif event.type == "agent_updated_stream_event":
current_agent = event.new_agent
@ -323,10 +323,10 @@ async def run_turn_streamed(
}
print("Yielding message: ", message)
yield ('message', message)
current_agent = event.new_agent
continue
# Handle run items (tools, messages, etc)
elif event.type == "run_item_stream_event":
if event.item.type == "tool_call_item":
@ -348,7 +348,7 @@ async def run_turn_streamed(
}
print("Yielding message: ", message)
yield ('message', message)
elif event.item.type == "tool_call_output_item":
message = {
'content': str(event.item.output),
@ -361,7 +361,7 @@ async def run_turn_streamed(
}
print("Yielding message: ", message)
yield ('message', message)
elif event.item.type == "message_output_item":
content = ""
if hasattr(event.item.raw_item, 'content'):

View file

@ -311,7 +311,7 @@ def create_response(messages=None, tokens_used=None, agent=None, error_msg=''):
)
def run(
async def run(
agent,
messages,
external_tools=None,
@ -344,19 +344,11 @@ def run(
"content": str(msg)
})
# Create a new event loop for this thread
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Run the agent with the formatted messages
logger.info("Beginning Swarm run with run_sync")
print("Beginning Swarm run with run_sync")
logger.info("Beginning Swarm run")
print("Beginning Swarm run")
try:
response = loop.run_until_complete(Runner.run(agent, formatted_messages))
response = await Runner.run(agent, formatted_messages)
except Exception as e:
logger.error(f"Error during run: {str(e)}")
print(f"Error during run: {str(e)}")