fix unit tests for tool module

This commit is contained in:
yzlin 2024-03-11 16:18:28 +08:00
parent bff3ef02bc
commit b5af9ccde6
7 changed files with 91 additions and 141 deletions

View file

@ -174,9 +174,10 @@ class BM25ToolRecommender(ToolRecommender):
doc_scores = self.bm25.get_scores(query_tokens)
top_indexes = np.argsort(doc_scores)[::-1][:topk]
recalled_tools = [list(self.tools.values())[index] for index in top_indexes]
print([doc_scores[index] for index in top_indexes])
print([recalled_tools[i].name for i in range(len(recalled_tools))])
print([recalled_tools[i].schemas["description"] for i in range(len(recalled_tools))])
logger.info(
f"Recalled tools: \n{[tool.name for tool in recalled_tools]}; Scores: {[doc_scores[index] for index in top_indexes]}"
)
return recalled_tools

View file

@ -134,7 +134,7 @@ def validate_tool_names(tools: Union[list[str], str]) -> str:
# one can define either tool names or tool type names, take union to get the whole set
if TOOL_REGISTRY.has_tool(key):
valid_tools.update({key: TOOL_REGISTRY.get_tool(key)})
elif TOOL_REGISTRY.tool_tool_tag(key):
elif TOOL_REGISTRY.has_tool_tag(key):
valid_tools.update(TOOL_REGISTRY.get_tools_by_tag(key))
else:
logger.warning(f"invalid tool name or tool type name: {key}, skipped")