tool invocation

This commit is contained in:
arkml 2025-03-24 16:10:43 +05:30 committed by Ramnique Singh
parent b131c1768e
commit b2fd9bf877
7 changed files with 574 additions and 162 deletions

431
apps/agents/poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -40,6 +40,7 @@ lxml = "^5.3.0"
markdownify = "^0.13.1" markdownify = "^0.13.1"
MarkupSafe = "^3.0.2" MarkupSafe = "^3.0.2"
mcp = "*" mcp = "*"
motor = "^3.7.0"
mypy-extensions = "^1.0.0" mypy-extensions = "^1.0.0"
nest-asyncio = "^1.6.0" nest-asyncio = "^1.6.0"
numpy = "^2.1.2" numpy = "^2.1.2"
@ -54,6 +55,8 @@ python-dateutil = "^2.8.2"
python-docx = "^1.1.2" python-docx = "^1.1.2"
python-dotenv = "^1.0.1" python-dotenv = "^1.0.1"
pytz = "^2024.2" pytz = "^2024.2"
qdrant_client = "^1.13.3"
redis = "^5.2.1"
requests = "^2.32.3" requests = "^2.32.3"
setuptools = "^75.1.0" setuptools = "^75.1.0"
six = "^1.16.0" six = "^1.16.0"

View file

@ -61,7 +61,8 @@ def chat():
tool_configs=data.get("tools", []), tool_configs=data.get("tools", []),
start_turn_with_start_agent=config.get("start_turn_with_start_agent", False), start_turn_with_start_agent=config.get("start_turn_with_start_agent", False),
state=data.get("state", {}), state=data.get("state", {}),
additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL] additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL],
complete_request=data
) )
logger.info('-'*200) logger.info('-'*200)

View file

@ -80,7 +80,7 @@ def create_final_response(response, turn_messages, tokens_used, all_agents):
def run_turn( def run_turn(
messages, start_agent_name, agent_configs, tool_configs, start_turn_with_start_agent, state={}, additional_tool_configs=[] messages, start_agent_name, agent_configs, tool_configs, start_turn_with_start_agent, state={}, additional_tool_configs=[], complete_request={}
): ):
""" """
Coordinates a single 'turn' of conversation or processing among agents. Coordinates a single 'turn' of conversation or processing among agents.
@ -129,7 +129,8 @@ def run_turn(
print("Initializing agents") print("Initializing agents")
new_agents = get_agents( new_agents = get_agents(
agent_configs=agent_configs, agent_configs=agent_configs,
tool_configs=tool_configs tool_configs=tool_configs,
complete_request=complete_request
) )
# Prepare escalation agent # Prepare escalation agent
last_new_agent = get_agent_by_name(last_agent_name, new_agents) last_new_agent = get_agent_by_name(last_agent_name, new_agents)

View file

@ -14,15 +14,14 @@ from agents import Agent as NewAgent, Runner, FunctionTool, RunContextWrapper
# Add import for OpenAI functionality # Add import for OpenAI functionality
from src.utils.common import common_logger as logger, generate_openai_output from src.utils.common import common_logger as logger, generate_openai_output
from typing import Any from typing import Any
# Create a dedicated logger for swarm wrapper from dataclasses import asdict
#logger = logging.getLogger("swarm_wrapper")
#logger.setLevel(logging.INFO)
import asyncio import asyncio
from mcp import ClientSession from mcp import ClientSession
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Optional, Dict from typing import List, Optional, Dict
from .tool_calling import call_rag_tool
class NewResponse(BaseModel): class NewResponse(BaseModel):
messages: List[Dict] messages: List[Dict]
@ -30,7 +29,7 @@ class NewResponse(BaseModel):
tokens_used: Optional[dict] = {} tokens_used: Optional[dict] = {}
error_msg: Optional[str] = "" error_msg: Optional[str] = ""
async def mock_tool(tool_name: str, args: str, mock_instructions: str) -> str: async def mock_tool(tool_name: str, args: str, tool_config: str) -> str:
""" """
Handles tool execution by either using mock instructions or generating a response. Handles tool execution by either using mock instructions or generating a response.
@ -45,10 +44,11 @@ async def mock_tool(tool_name: str, args: str, mock_instructions: str) -> str:
print(f"Mock tool called for: {tool_name}") print(f"Mock tool called for: {tool_name}")
# For non-mocked tools, generate a realistic response # For non-mocked tools, generate a realistic response
description = mock_instructions description = tool_config.get("description", "")
mock_instructions = tool_config.get("mockInstructions", "")
messages = [ messages = [
{"role": "system", "content": f"You are simulating the execution of a tool called '{tool_name}'. Here are the mock instructions: {description}. Generate a realistic response as if the tool was actually executed with the given parameters."}, {"role": "system", "content": f"You are simulating the execution of a tool called '{tool_name}'.Here is the description of the tool: {description}. Here are the instructions for the mock tool: {mock_instructions}. Generate a realistic response as if the tool was actually executed with the given parameters."},
{"role": "user", "content": f"Generate a realistic response for the tool '{tool_name}' with these parameters: {args}. The response should be concise and focused on what the tool would actually return."} {"role": "user", "content": f"Generate a realistic response for the tool '{tool_name}' with these parameters: {args}. The response should be concise and focused on what the tool would actually return."}
] ]
@ -56,7 +56,7 @@ async def mock_tool(tool_name: str, args: str, mock_instructions: str) -> str:
response_content = generate_openai_output(messages, output_type='text', model="gpt-4o") response_content = generate_openai_output(messages, output_type='text', model="gpt-4o")
return response_content return response_content
async def call_webhook(tool_name: str, args: str) -> str: async def call_webhook(tool_name: str, args: str, webhook_url: str) -> str:
""" """
Calls the webhook with the given tool name and arguments. Calls the webhook with the given tool name and arguments.
@ -67,12 +67,11 @@ async def call_webhook(tool_name: str, args: str) -> str:
Returns: Returns:
str: The response from the webhook, or an error message if the call fails. str: The response from the webhook, or an error message if the call fails.
""" """
webhook_url = "http://localhost:4020/tool_call"
content_dict = { content_dict = {
"toolCall": { "toolCall": {
"function": { "function": {
"name": tool_name, "name": tool_name,
"arguments": args # Assumes args is a valid JSON string "arguments": args
} }
} }
} }
@ -93,26 +92,21 @@ async def call_webhook(tool_name: str, args: str) -> str:
print(f"Exception in call_webhook: {str(e)}") print(f"Exception in call_webhook: {str(e)}")
return f"Error: Failed to call webhook - {str(e)}" return f"Error: Failed to call webhook - {str(e)}"
async def call_mcp(tool_name: str, args: str, mcp_server_name: str, mcp_servers: dict) -> str: async def call_mcp(tool_name: str, args: str, mcp_server_url: str) -> str:
""" """
Calls the MCP with the given tool name and arguments. Calls the MCP with the given tool name and arguments.
""" """
server_url = "http://localhost:8000/sse" #mcp_servers.get(tool_name, None)
print(args) async with sse_client(url=mcp_server_url) as streams:
async with sse_client(url=server_url) as streams:
# Create a client session using the SSE streams
async with ClientSession(*streams) as session: async with ClientSession(*streams) as session:
# Initialize the session (perform handshake with the server)
await session.initialize() await session.initialize()
# Call the tool on the server and await the response jargs = json.loads(args)
response = await session.call_tool(tool_name, arguments=json.loads(args)) response = await session.call_tool(tool_name, arguments=jargs)
json_output = json.dumps([item.__dict__ for item in response.content], indent=2)
# Print the response received from the server return json_output
print("Server response:", response)
return response async def catch_all(ctx: RunContextWrapper[Any], args: str, tool_name: str, tool_config: dict, complete_request: dict) -> str:
def catch_all(ctx: RunContextWrapper[Any], args: str, tool_name: str, tool_config: dict) -> str:
""" """
Handles all tool calls by dispatching to appropriate functions. Handles all tool calls by dispatching to appropriate functions.
""" """
@ -128,30 +122,54 @@ def catch_all(ctx: RunContextWrapper[Any], args: str, tool_name: str, tool_confi
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
response_content = None response_content = None
# Check if this tool should be mocked
if tool_config.get("mockTool", False): if tool_config.get("mockTool", False):
# Handle mock tool synchronously # Call mock_tool to handle the response (it will decide whether to use mock instructions or generate a response)
description = tool_config.get("description", "") response_content = await mock_tool(tool_name, args, tool_config)
messages = [
{"role": "system", "content": f"You are simulating the execution of a tool called '{tool_name}'. The tool has this description: {description}. Generate a realistic response as if the tool was actually executed with the given parameters."},
{"role": "user", "content": f"Generate a realistic response for the tool '{tool_name}' with these parameters: {args}. The response should be concise and focused on what the tool would actually return."}
]
response_content = generate_openai_output(messages, output_type='text', model="gpt-4o")
elif tool_config.get("isMcp", False):
# Handle MCP calls
response_content = loop.run_until_complete(
call_mcp(tool_name, args, tool_config.get("mcpServerName", ""), {})
)
else:
# Handle webhook calls
response_content = loop.run_until_complete(
call_webhook(tool_name, args)
)
print(response_content) print(response_content)
elif tool_config.get("isMcp", False):
mcp_server_name = tool_config.get("mcpServerName", "")
mcp_servers = complete_request.get("mcpServers", {})
mcp_server_url = next((server.get("url", "") for server in mcp_servers if server.get("name") == mcp_server_name), "")
response_content = await call_mcp(tool_name, args, mcp_server_url)
else:
webhook_url = complete_request.get("toolWebhookUrl", "")
response_content = await call_webhook(tool_name, args, webhook_url)
return response_content return response_content
def get_agents(agent_configs, tool_configs):
def get_rag_tool(config: dict, complete_request: dict) -> Optional[FunctionTool]:
    """
    Creates the RAG tool ("getArticleInfo") for an agent, if it has RAG data sources.

    Args:
        config (dict): The agent configuration. Its "ragDataSources" entry is
            passed to call_rag_tool as the list of source IDs.
        complete_request (dict): The full incoming request; only "projectId"
            is read from it here.

    Returns:
        Optional[FunctionTool]: The configured tool, or None when the agent
        config has no (or an empty) "ragDataSources" entry.
    """
    # Guard clause: agents without RAG data sources get no RAG tool.
    if not config.get("ragDataSources", None):
        return None

    project_id = complete_request.get("projectId", "")
    print("getArticleInfo")

    # JSON schema of the tool arguments: a single required "query" string.
    params = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "The query to search for"
            }
        },
        "additionalProperties": False,
        "required": [
            "query"
        ]
    }

    async def _invoke_rag_tool(ctx, args: str) -> str:
        # args is the JSON-encoded argument string; the data sources are read
        # from the agent config at call time and the top 3 chunks are requested.
        return await call_rag_tool(
            project_id,
            json.loads(args)['query'],
            config.get("ragDataSources", []),
            "chunks",
            3,
        )

    return FunctionTool(
        name="getArticleInfo",
        description="Get information about an article",
        params_json_schema=params,
        on_invoke_tool=_invoke_rag_tool,
    )
def get_agents(agent_configs, tool_configs, complete_request):
""" """
Creates and initializes Agent objects based on their configurations and connections. Creates and initializes Agent objects based on their configurations and connections.
""" """
@ -181,7 +199,15 @@ def get_agents(agent_configs, tool_configs):
print(f"Agent {agent_config['name']} has {len(agent_config['tools'])} configured tools") print(f"Agent {agent_config['name']} has {len(agent_config['tools'])} configured tools")
new_tools = [] new_tools = []
print(agent_config)
rag_tool = get_rag_tool(agent_config, complete_request)
if rag_tool:
new_tools.append(rag_tool)
logger.debug(f"Added rag tool to agent {agent_config['name']}")
print(f"Added rag tool to agent {agent_config['name']}")
for tool_name in agent_config["tools"]: for tool_name in agent_config["tools"]:
tool_config = get_tool_config_by_name(tool_configs, tool_name) tool_config = get_tool_config_by_name(tool_configs, tool_name)
if tool_config: if tool_config:
@ -195,8 +221,8 @@ def get_agents(agent_configs, tool_configs):
name=tool_name, name=tool_name,
description=tool_config["description"], description=tool_config["description"],
params_json_schema=tool_config["parameters"], params_json_schema=tool_config["parameters"],
on_invoke_tool=lambda ctx, args, _tool_name=tool_name, _tool_config=tool_config: on_invoke_tool=lambda ctx, args, _tool_name=tool_name, _tool_config=tool_config, _complete_request=complete_request:
catch_all(ctx, args, _tool_name, _tool_config) catch_all(ctx, args, _tool_name, _tool_config, _complete_request)
) )
new_tools.append(tool) new_tools.append(tool)
logger.debug(f"Added tool {tool_name} to agent {agent_config['name']}") logger.debug(f"Added tool {tool_name} to agent {agent_config['name']}")

View file

@ -0,0 +1,143 @@
from bson.objectid import ObjectId
from openai import OpenAI
import os
from motor.motor_asyncio import AsyncIOMotorClient
import asyncio
from dataclasses import dataclass
from typing import Dict, List, Any
from qdrant_client import QdrantClient
import json
# Initialize MongoDB client
# The connection string is read from MONGODB_URI, falling back to a local instance.
mongo_uri = os.environ.get("MONGODB_URI", "mongodb://localhost:27017")
# Motor (async) client: queries on the collections below are awaited by callers.
mongo_client = AsyncIOMotorClient(mongo_uri)
# Database "rowboat" and the two collections used by call_rag_tool.
db = mongo_client.rowboat
data_sources_collection = db['sources']
data_source_docs_collection = db['source_docs']
# Qdrant client used for the vector search. QDRANT_URL has no fallback here, so
# url is None when the variable is unset.
# NOTE(review): confirm what QdrantClient does with url=None in this deployment.
qdrant_client = QdrantClient(url=os.environ.get("QDRANT_URL"))
# Initialize OpenAI client
# This is the synchronous OpenAI client; the key is read from OPENAI_API_KEY.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# Define embedding model
# Model name passed to embed() by call_rag_tool.
embedding_model = "text-embedding-3-small"
async def embed(model: str, value: str) -> dict:
    """
    Generate embeddings using OpenAI's embedding models.

    Args:
        model (str): The embedding model to use (e.g., "text-embedding-3-small").
        value (str): The text to embed.

    Returns:
        dict: A dictionary with a single "embedding" key holding the embedding
        of the first (only) input returned by the API.
    """
    # `client` is the synchronous OpenAI client, so calling it directly inside
    # a coroutine would block the event loop for the whole HTTP round trip.
    # Run it in a worker thread instead.
    response = await asyncio.to_thread(
        client.embeddings.create,
        model=model,
        input=value,
    )
    return {"embedding": response.data[0].embedding}
async def call_rag_tool(
    project_id: str,
    query: str,
    source_ids: list[str],
    return_type: str,
    k: int,
) -> str:
    """
    Runs the RAG tool call to retrieve information based on the query and source IDs.

    Args:
        project_id (str): The ID of the project.
        query (str): The query string to search for.
        source_ids (list[str]): List of source IDs to filter the search.
        return_type (str): 'chunks' returns the matched chunks as-is; any other
            value replaces each chunk's content with the content of the
            document it belongs to.
        k (int): The number of results to return.

    Returns:
        str: A JSON string of the form {"Information": [...]}, or an empty
        string when none of the requested sources is an active source of the
        project.
    """
    print("\n\n calling rag tool \n\n")
    print(query)

    # Fetch all active data sources for this project
    sources = await data_sources_collection.find({
        "projectId": project_id,
        "active": True
    }).to_list(length=None)
    print(sources)

    # Filter sources to those in source_ids
    valid_source_ids = [
        str(s["_id"]) for s in sources if str(s["_id"]) in source_ids
    ]
    print(valid_source_ids)

    # If no valid sources are found, return empty results. This check runs
    # before the embedding call so no OpenAI request is made for a query that
    # cannot match anything.
    if not valid_source_ids:
        return ''

    # Create embedding for the query (the vector itself is not printed: it is
    # large and adds nothing readable to the output).
    embed_result = await embed(model=embedding_model, value=query)

    # Perform Qdrant vector search. The Qdrant client used here is synchronous,
    # so the call runs in a worker thread to keep the event loop responsive.
    qdrant_results = await asyncio.to_thread(
        qdrant_client.search,
        collection_name="embeddings",
        query_vector=embed_result["embedding"],
        query_filter={
            "must": [
                {"key": "projectId", "match": {"value": project_id}},
                {"key": "sourceId", "match": {"any": valid_source_ids}},
            ]
        },
        limit=k,
        with_payload=True,
    )

    # Map the Qdrant results to the desired format
    results = [
        {
            "title": point.payload["title"],
            "name": point.payload["name"],
            "content": point.payload["content"],
            "docId": point.payload["docId"],
            "sourceId": point.payload["sourceId"],
        }
        for point in qdrant_results
    ]
    print(return_type)
    print(results)

    # If return_type is 'chunks', return the results directly
    if return_type == "chunks":
        return json.dumps({"Information": results}, indent=2)

    # Otherwise, fetch the full document contents from MongoDB
    doc_ids = [ObjectId(r["docId"]) for r in results]
    docs = await data_source_docs_collection.find({"_id": {"$in": doc_ids}}).to_list(length=None)

    # Create a dictionary for quick lookup of documents by their string ID
    doc_dict = {str(doc["_id"]): doc for doc in docs}

    # Update the results with the full document content; a chunk whose
    # document was not found gets an empty content string.
    results = [
        {**r, "content": doc_dict.get(r["docId"], {}).get("content", "")}
        for r in results
    ]

    # Convert results to a JSON string
    formatted_string = json.dumps({"Information": results}, indent=2)
    print(formatted_string)
    return formatted_string
if __name__ == "__main__":
    # Manual smoke test: run one retrieval against hard-coded sample IDs and
    # request full documents ("docs") rather than chunks.
    _sample_call = {
        "project_id": "faf2bfb3-41d4-4299-b0d2-048581ea9bd8",
        "query": "What is the range on your scooter",
        "source_ids": ["67e102c9fab4514d7aaeb5a4"],
        "return_type": "docs",
        "k": 3,
    }
    asyncio.run(call_rag_tool(**_sample_call))

View file

@ -81,7 +81,7 @@ if __name__ == "__main__":
"startAgent": start_agent_name "startAgent": start_agent_name
} }
print(json.dumps(request_json, indent=2)) print(json.dumps(request_json, indent=2))
print(complete_request)
resp_messages, resp_tokens_used, resp_state = run_turn( resp_messages, resp_tokens_used, resp_state = run_turn(
messages=messages, messages=messages,
start_agent_name=start_agent_name, start_agent_name=start_agent_name,
@ -89,7 +89,8 @@ if __name__ == "__main__":
tool_configs=tool_configs, tool_configs=tool_configs,
start_turn_with_start_agent=config.get("start_turn_with_start_agent", False), start_turn_with_start_agent=config.get("start_turn_with_start_agent", False),
state=state, state=state,
additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL] additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL],
complete_request=complete_request
) )
state = resp_state state = resp_state
resp_messages = order_messages(resp_messages) resp_messages = order_messages(resp_messages)