mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-05-05 21:32:46 +02:00
Add async/sync wrapper over run
This commit is contained in:
parent
a2a19f37d7
commit
e725fc6276
5 changed files with 777 additions and 39 deletions
|
|
@ -112,34 +112,43 @@ async def call_mcp(tool_name: str, args: str, mcp_server_name: str, mcp_servers:
|
|||
|
||||
return response
|
||||
|
||||
async def catch_all(ctx: RunContextWrapper[Any], args: str, tool_name: str, tool_config: dict) -> str:
|
||||
def catch_all(ctx: RunContextWrapper[Any], args: str, tool_name: str, tool_config: dict) -> str:
|
||||
"""
|
||||
Handles all tool calls by dispatching to the mock_tool function.
|
||||
|
||||
Args:
|
||||
ctx: The run context wrapper
|
||||
args: The arguments passed to the tool
|
||||
tool_name: The name of the tool being called
|
||||
tool_config: The configuration of the tool
|
||||
|
||||
Returns:
|
||||
The response from the tool
|
||||
Handles all tool calls by dispatching to appropriate functions.
|
||||
"""
|
||||
print(f"Catch all called for tool: {tool_name}")
|
||||
print(f"Args: {args}")
|
||||
print(f"Tool config: {tool_config}")
|
||||
|
||||
# Create event loop for async operations
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
response_content = None
|
||||
# Check if this tool should be mocked
|
||||
if tool_config.get("mockTool", False):
|
||||
# Call mock_tool to handle the response (it will decide whether to use mock instructions or generate a response)
|
||||
response_content = await mock_tool(tool_name, args, tool_config)
|
||||
print(response_content)
|
||||
# Handle mock tool synchronously
|
||||
description = tool_config.get("description", "")
|
||||
messages = [
|
||||
{"role": "system", "content": f"You are simulating the execution of a tool called '{tool_name}'. The tool has this description: {description}. Generate a realistic response as if the tool was actually executed with the given parameters."},
|
||||
{"role": "user", "content": f"Generate a realistic response for the tool '{tool_name}' with these parameters: {args}. The response should be concise and focused on what the tool would actually return."}
|
||||
]
|
||||
response_content = generate_openai_output(messages, output_type='text', model="gpt-4o")
|
||||
elif tool_config.get("isMcp", False):
|
||||
response_content = await call_mcp(tool_name, args, tool_config.get("mcpServerName", ""), {})
|
||||
print(response_content)
|
||||
# Handle MCP calls
|
||||
response_content = loop.run_until_complete(
|
||||
call_mcp(tool_name, args, tool_config.get("mcpServerName", ""), {})
|
||||
)
|
||||
else:
|
||||
response_content = await call_webhook(tool_name, args)
|
||||
print(response_content)
|
||||
# Handle webhook calls
|
||||
response_content = loop.run_until_complete(
|
||||
call_webhook(tool_name, args)
|
||||
)
|
||||
|
||||
print(response_content)
|
||||
return response_content
|
||||
|
||||
def get_agents(agent_configs, tool_configs):
|
||||
|
|
@ -261,19 +270,6 @@ def run(
|
|||
):
|
||||
"""
|
||||
Wrapper function for initializing and running the Swarm client.
|
||||
|
||||
Args:
|
||||
agent: The agent to run
|
||||
messages: List of messages for the agent to process
|
||||
execute_tools: Whether to execute tools or just return tool calls
|
||||
external_tools: List of external tools available to the agent
|
||||
localize_history: Whether to localize history for the agent
|
||||
parent_has_child_history: Whether parent agents have access to child agent history
|
||||
max_messages_per_turn: Maximum number of messages to process in a turn
|
||||
tokens_used: Dictionary tracking token usage
|
||||
|
||||
Returns:
|
||||
Response object from the Swarm client
|
||||
"""
|
||||
logger.info(f"Initializing Swarm client for agent: {agent.name}")
|
||||
print(f"Initializing Swarm client for agent: {agent.name}")
|
||||
|
|
@ -287,26 +283,36 @@ def run(
|
|||
# Format messages to ensure they're compatible with the OpenAI API
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
# Check if the message has the expected format
|
||||
if isinstance(msg, dict) and "content" in msg:
|
||||
# Make sure the message has the required fields for OpenAI API
|
||||
formatted_msg = {
|
||||
"role": msg.get("role", "user"),
|
||||
"content": msg["content"]
|
||||
}
|
||||
formatted_messages.append(formatted_msg)
|
||||
else:
|
||||
# If the message is just a string, assume it's a user message
|
||||
formatted_messages.append({
|
||||
"role": "user",
|
||||
"content": str(msg)
|
||||
})
|
||||
|
||||
# Create a new event loop for this thread
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Run the agent with the formatted messages
|
||||
logger.info("Beginning Swarm run with run_sync")
|
||||
print("Beginning Swarm run with run_sync")
|
||||
response2 = Runner.run_sync(agent, formatted_messages)
|
||||
|
||||
try:
|
||||
response = loop.run_until_complete(Runner.run(agent, formatted_messages))
|
||||
except Exception as e:
|
||||
logger.error(f"Error during run: {str(e)}")
|
||||
print(f"Error during run: {str(e)}")
|
||||
raise
|
||||
|
||||
logger.info(f"Completed Swarm run for agent: {agent.name}")
|
||||
print(f"Completed Swarm run for agent: {agent.name}")
|
||||
return response2
|
||||
return response
|
||||
Loading…
Add table
Add a link
Reference in a new issue