[EXP] Make rag_search a library tool

This commit is contained in:
akhisud3195 2025-05-08 12:42:58 +05:30
parent bdc7f01ea2
commit 968dfacd65
3 changed files with 150 additions and 84 deletions

View file

@ -153,7 +153,7 @@ def get_rag_tool(config: dict, complete_request: dict) -> FunctionTool:
"""
project_id = complete_request.get("projectId", "")
if config.get("ragDataSources", None):
print("getArticleInfo")
print("rag_search")
params = {
"type": "object",
"properties": {
@ -168,7 +168,7 @@ def get_rag_tool(config: dict, complete_request: dict) -> FunctionTool:
]
}
tool = FunctionTool(
name="getArticleInfo",
name="rag_search",
description="Get information about an article",
params_json_schema=params,
on_invoke_tool=lambda ctx, args: call_rag_tool(project_id, json.loads(args)['query'], config.get("ragDataSources", []), "chunks", 3)
@ -208,11 +208,6 @@ def get_agents(agent_configs, tool_configs, complete_request):
print(f"Agent {agent_config['name']} has {len(agent_config['tools'])} configured tools")
new_tools = []
rag_tool = get_rag_tool(agent_config, complete_request)
if rag_tool:
new_tools.append(rag_tool)
print(f"Added rag tool to agent {agent_config['name']}")
agent_config = add_rag_instructions_to_agent(agent_config, rag_tool.name)
for tool_name in agent_config["tools"]:
@ -225,6 +220,10 @@ def get_agents(agent_configs, tool_configs, complete_request):
})
if tool_name == "web_search":
tool = WebSearchTool()
elif tool_name == "rag_search":
tool = get_rag_tool(agent_config, complete_request)
else:
tool = FunctionTool(
name=tool_name,
@ -234,8 +233,10 @@ def get_agents(agent_configs, tool_configs, complete_request):
on_invoke_tool=lambda ctx, args, _tool_name=tool_name, _tool_config=tool_config, _complete_request=complete_request:
catch_all(ctx, args, _tool_name, _tool_config, _complete_request)
)
new_tools.append(tool)
print(f"Added tool {tool_name} to agent {agent_config['name']}")
if tool:
new_tools.append(tool)
logger.debug(f"Added tool {tool_name} to agent {agent_config['name']}")
print(f"Added tool {tool_name} to agent {agent_config['name']}")
else:
print(f"WARNING: Tool {tool_name} not found in tool_configs")