mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-05-05 21:32:46 +02:00
tool invocation
This commit is contained in:
parent
b131c1768e
commit
b2fd9bf877
7 changed files with 574 additions and 162 deletions
|
|
@ -14,15 +14,14 @@ from agents import Agent as NewAgent, Runner, FunctionTool, RunContextWrapper
|
|||
# Add import for OpenAI functionality
|
||||
from src.utils.common import common_logger as logger, generate_openai_output
|
||||
from typing import Any
|
||||
# Create a dedicated logger for swarm wrapper
|
||||
#logger = logging.getLogger("swarm_wrapper")
|
||||
#logger.setLevel(logging.INFO)
|
||||
from dataclasses import asdict
|
||||
import asyncio
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Dict
|
||||
from .tool_calling import call_rag_tool
|
||||
|
||||
class NewResponse(BaseModel):
|
||||
messages: List[Dict]
|
||||
|
|
@ -30,7 +29,7 @@ class NewResponse(BaseModel):
|
|||
tokens_used: Optional[dict] = {}
|
||||
error_msg: Optional[str] = ""
|
||||
|
||||
async def mock_tool(tool_name: str, args: str, mock_instructions: str) -> str:
|
||||
async def mock_tool(tool_name: str, args: str, tool_config: str) -> str:
|
||||
"""
|
||||
Handles tool execution by either using mock instructions or generating a response.
|
||||
|
||||
|
|
@ -45,10 +44,11 @@ async def mock_tool(tool_name: str, args: str, mock_instructions: str) -> str:
|
|||
print(f"Mock tool called for: {tool_name}")
|
||||
|
||||
# For non-mocked tools, generate a realistic response
|
||||
description = mock_instructions
|
||||
description = tool_config.get("description", "")
|
||||
mock_instructions = tool_config.get("mockInstructions", "")
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": f"You are simulating the execution of a tool called '{tool_name}'. Here are the mock instructions: {description}. Generate a realistic response as if the tool was actually executed with the given parameters."},
|
||||
{"role": "system", "content": f"You are simulating the execution of a tool called '{tool_name}'.Here is the description of the tool: {description}. Here are the instructions for the mock tool: {mock_instructions}. Generate a realistic response as if the tool was actually executed with the given parameters."},
|
||||
{"role": "user", "content": f"Generate a realistic response for the tool '{tool_name}' with these parameters: {args}. The response should be concise and focused on what the tool would actually return."}
|
||||
]
|
||||
|
||||
|
|
@ -56,7 +56,7 @@ async def mock_tool(tool_name: str, args: str, mock_instructions: str) -> str:
|
|||
response_content = generate_openai_output(messages, output_type='text', model="gpt-4o")
|
||||
return response_content
|
||||
|
||||
async def call_webhook(tool_name: str, args: str) -> str:
|
||||
async def call_webhook(tool_name: str, args: str, webhook_url: str) -> str:
|
||||
"""
|
||||
Calls the webhook with the given tool name and arguments.
|
||||
|
||||
|
|
@ -67,12 +67,11 @@ async def call_webhook(tool_name: str, args: str) -> str:
|
|||
Returns:
|
||||
str: The response from the webhook, or an error message if the call fails.
|
||||
"""
|
||||
webhook_url = "http://localhost:4020/tool_call"
|
||||
content_dict = {
|
||||
"toolCall": {
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": args # Assumes args is a valid JSON string
|
||||
"arguments": args
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -93,33 +92,28 @@ async def call_webhook(tool_name: str, args: str) -> str:
|
|||
print(f"Exception in call_webhook: {str(e)}")
|
||||
return f"Error: Failed to call webhook - {str(e)}"
|
||||
|
||||
async def call_mcp(tool_name: str, args: str, mcp_server_name: str, mcp_servers: dict) -> str:
|
||||
async def call_mcp(tool_name: str, args: str, mcp_server_url: str) -> str:
|
||||
"""
|
||||
Calls the MCP with the given tool name and arguments.
|
||||
"""
|
||||
server_url = "http://localhost:8000/sse" #mcp_servers.get(tool_name, None)
|
||||
print(args)
|
||||
async with sse_client(url=server_url) as streams:
|
||||
# Create a client session using the SSE streams
|
||||
|
||||
async with sse_client(url=mcp_server_url) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
# Initialize the session (perform handshake with the server)
|
||||
await session.initialize()
|
||||
# Call the tool on the server and await the response
|
||||
response = await session.call_tool(tool_name, arguments=json.loads(args))
|
||||
jargs = json.loads(args)
|
||||
response = await session.call_tool(tool_name, arguments=jargs)
|
||||
json_output = json.dumps([item.__dict__ for item in response.content], indent=2)
|
||||
|
||||
# Print the response received from the server
|
||||
print("Server response:", response)
|
||||
return json_output
|
||||
|
||||
return response
|
||||
|
||||
def catch_all(ctx: RunContextWrapper[Any], args: str, tool_name: str, tool_config: dict) -> str:
|
||||
async def catch_all(ctx: RunContextWrapper[Any], args: str, tool_name: str, tool_config: dict, complete_request: dict) -> str:
|
||||
"""
|
||||
Handles all tool calls by dispatching to appropriate functions.
|
||||
"""
|
||||
print(f"Catch all called for tool: {tool_name}")
|
||||
print(f"Args: {args}")
|
||||
print(f"Tool config: {tool_config}")
|
||||
|
||||
|
||||
# Create event loop for async operations
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
|
@ -128,30 +122,54 @@ def catch_all(ctx: RunContextWrapper[Any], args: str, tool_name: str, tool_confi
|
|||
asyncio.set_event_loop(loop)
|
||||
|
||||
response_content = None
|
||||
# Check if this tool should be mocked
|
||||
if tool_config.get("mockTool", False):
|
||||
# Handle mock tool synchronously
|
||||
description = tool_config.get("description", "")
|
||||
messages = [
|
||||
{"role": "system", "content": f"You are simulating the execution of a tool called '{tool_name}'. The tool has this description: {description}. Generate a realistic response as if the tool was actually executed with the given parameters."},
|
||||
{"role": "user", "content": f"Generate a realistic response for the tool '{tool_name}' with these parameters: {args}. The response should be concise and focused on what the tool would actually return."}
|
||||
]
|
||||
response_content = generate_openai_output(messages, output_type='text', model="gpt-4o")
|
||||
# Call mock_tool to handle the response (it will decide whether to use mock instructions or generate a response)
|
||||
response_content = await mock_tool(tool_name, args, tool_config)
|
||||
print(response_content)
|
||||
elif tool_config.get("isMcp", False):
|
||||
# Handle MCP calls
|
||||
response_content = loop.run_until_complete(
|
||||
call_mcp(tool_name, args, tool_config.get("mcpServerName", ""), {})
|
||||
)
|
||||
mcp_server_name = tool_config.get("mcpServerName", "")
|
||||
mcp_servers = complete_request.get("mcpServers", {})
|
||||
mcp_server_url = next((server.get("url", "") for server in mcp_servers if server.get("name") == mcp_server_name), "")
|
||||
response_content = await call_mcp(tool_name, args, mcp_server_url)
|
||||
else:
|
||||
# Handle webhook calls
|
||||
response_content = loop.run_until_complete(
|
||||
call_webhook(tool_name, args)
|
||||
)
|
||||
|
||||
print(response_content)
|
||||
webhook_url = complete_request.get("toolWebhookUrl", "")
|
||||
response_content = await call_webhook(tool_name, args, webhook_url)
|
||||
return response_content
|
||||
|
||||
def get_agents(agent_configs, tool_configs):
|
||||
|
||||
def get_rag_tool(config: dict, complete_request: dict) -> FunctionTool:
|
||||
"""
|
||||
Creates a RAG tool based on the provided configuration.
|
||||
"""
|
||||
project_id = complete_request.get("projectId", "")
|
||||
if config.get("ragDataSources", None):
|
||||
print("getArticleInfo")
|
||||
params = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The query to search for"
|
||||
}
|
||||
},
|
||||
"additionalProperties": False,
|
||||
"required": [
|
||||
"query"
|
||||
]
|
||||
}
|
||||
tool = FunctionTool(
|
||||
name="getArticleInfo",
|
||||
description="Get information about an article",
|
||||
params_json_schema=params,
|
||||
on_invoke_tool=lambda ctx, args: call_rag_tool(project_id, json.loads(args)['query'], config.get("ragDataSources", []), "chunks", 3)
|
||||
)
|
||||
return tool
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def get_agents(agent_configs, tool_configs, complete_request):
|
||||
"""
|
||||
Creates and initializes Agent objects based on their configurations and connections.
|
||||
"""
|
||||
|
|
@ -181,7 +199,15 @@ def get_agents(agent_configs, tool_configs):
|
|||
print(f"Agent {agent_config['name']} has {len(agent_config['tools'])} configured tools")
|
||||
|
||||
new_tools = []
|
||||
print(agent_config)
|
||||
rag_tool = get_rag_tool(agent_config, complete_request)
|
||||
if rag_tool:
|
||||
new_tools.append(rag_tool)
|
||||
logger.debug(f"Added rag tool to agent {agent_config['name']}")
|
||||
print(f"Added rag tool to agent {agent_config['name']}")
|
||||
|
||||
for tool_name in agent_config["tools"]:
|
||||
|
||||
tool_config = get_tool_config_by_name(tool_configs, tool_name)
|
||||
|
||||
if tool_config:
|
||||
|
|
@ -195,8 +221,8 @@ def get_agents(agent_configs, tool_configs):
|
|||
name=tool_name,
|
||||
description=tool_config["description"],
|
||||
params_json_schema=tool_config["parameters"],
|
||||
on_invoke_tool=lambda ctx, args, _tool_name=tool_name, _tool_config=tool_config:
|
||||
catch_all(ctx, args, _tool_name, _tool_config)
|
||||
on_invoke_tool=lambda ctx, args, _tool_name=tool_name, _tool_config=tool_config, _complete_request=complete_request:
|
||||
catch_all(ctx, args, _tool_name, _tool_config, _complete_request)
|
||||
)
|
||||
new_tools.append(tool)
|
||||
logger.debug(f"Added tool {tool_name} to agent {agent_config['name']}")
|
||||
|
|
@ -305,7 +331,7 @@ def run(
|
|||
# Run the agent with the formatted messages
|
||||
logger.info("Beginning Swarm run with run_sync")
|
||||
print("Beginning Swarm run with run_sync")
|
||||
|
||||
|
||||
try:
|
||||
response = loop.run_until_complete(Runner.run(agent, formatted_messages))
|
||||
except Exception as e:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue