fix unit tests for tool module

This commit is contained in:
yzlin 2024-03-11 16:18:28 +08:00
parent bff3ef02bc
commit b5af9ccde6
7 changed files with 91 additions and 141 deletions

View file

@ -42,9 +42,7 @@ class WritePlan(Action):
"""
async def run(self, context: list[Message], max_tasks: int = 5, use_tools: bool = False) -> str:
task_type_desc = "\n".join(
[f"- **{tt.type_name}**: {tt.value.desc}" for tt in TaskType]
) # task type are binded with tool type now, should be improved in the future
task_type_desc = "\n".join([f"- **{tt.type_name}**: {tt.value.desc}" for tt in TaskType])
prompt = self.PROMPT_TEMPLATE.format(
context="\n".join([str(ct) for ct in context]), max_tasks=max_tasks, task_type_desc=task_type_desc
)

View file

@ -156,11 +156,16 @@ class Interpreter(Role):
return code, todo
async def _check_data(self):
if not self.use_plan or self.planner.plan.current_task.task_type not in [
TaskType.DATA_PREPROCESS.type_name,
TaskType.FEATURE_ENGINEERING.type_name,
TaskType.MODEL_TRAIN.type_name,
]:
if (
not self.use_plan
or not self.planner.plan.get_finished_tasks()
or self.planner.plan.current_task.task_type
not in [
TaskType.DATA_PREPROCESS.type_name,
TaskType.FEATURE_ENGINEERING.type_name,
TaskType.MODEL_TRAIN.type_name,
]
):
return
logger.info("Check updated data")
code = await CheckData().run(self.planner.plan)

View file

@ -174,9 +174,10 @@ class BM25ToolRecommender(ToolRecommender):
doc_scores = self.bm25.get_scores(query_tokens)
top_indexes = np.argsort(doc_scores)[::-1][:topk]
recalled_tools = [list(self.tools.values())[index] for index in top_indexes]
print([doc_scores[index] for index in top_indexes])
print([recalled_tools[i].name for i in range(len(recalled_tools))])
print([recalled_tools[i].schemas["description"] for i in range(len(recalled_tools))])
logger.info(
f"Recalled tools: \n{[tool.name for tool in recalled_tools]}; Scores: {[doc_scores[index] for index in top_indexes]}"
)
return recalled_tools

View file

@ -134,7 +134,7 @@ def validate_tool_names(tools: Union[list[str], str]) -> str:
# one can define either tool names or tool type names, take union to get the whole set
if TOOL_REGISTRY.has_tool(key):
valid_tools.update({key: TOOL_REGISTRY.get_tool(key)})
elif TOOL_REGISTRY.tool_tool_tag(key):
elif TOOL_REGISTRY.has_tool_tag(key):
valid_tools.update(TOOL_REGISTRY.get_tools_by_tag(key))
else:
logger.warning(f"invalid tool name or tool type name: {key}, skipped")