mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-02 14:45:17 +02:00
fix unit tests for tool module
This commit is contained in:
parent
bff3ef02bc
commit
b5af9ccde6
7 changed files with 91 additions and 141 deletions
|
|
@ -42,9 +42,7 @@ class WritePlan(Action):
|
|||
"""
|
||||
|
||||
async def run(self, context: list[Message], max_tasks: int = 5, use_tools: bool = False) -> str:
|
||||
task_type_desc = "\n".join(
|
||||
[f"- **{tt.type_name}**: {tt.value.desc}" for tt in TaskType]
|
||||
) # task type are binded with tool type now, should be improved in the future
|
||||
task_type_desc = "\n".join([f"- **{tt.type_name}**: {tt.value.desc}" for tt in TaskType])
|
||||
prompt = self.PROMPT_TEMPLATE.format(
|
||||
context="\n".join([str(ct) for ct in context]), max_tasks=max_tasks, task_type_desc=task_type_desc
|
||||
)
|
||||
|
|
|
|||
|
|
@ -156,11 +156,16 @@ class Interpreter(Role):
|
|||
return code, todo
|
||||
|
||||
async def _check_data(self):
|
||||
if not self.use_plan or self.planner.plan.current_task.task_type not in [
|
||||
TaskType.DATA_PREPROCESS.type_name,
|
||||
TaskType.FEATURE_ENGINEERING.type_name,
|
||||
TaskType.MODEL_TRAIN.type_name,
|
||||
]:
|
||||
if (
|
||||
not self.use_plan
|
||||
or not self.planner.plan.get_finished_tasks()
|
||||
or self.planner.plan.current_task.task_type
|
||||
not in [
|
||||
TaskType.DATA_PREPROCESS.type_name,
|
||||
TaskType.FEATURE_ENGINEERING.type_name,
|
||||
TaskType.MODEL_TRAIN.type_name,
|
||||
]
|
||||
):
|
||||
return
|
||||
logger.info("Check updated data")
|
||||
code = await CheckData().run(self.planner.plan)
|
||||
|
|
|
|||
|
|
@ -174,9 +174,10 @@ class BM25ToolRecommender(ToolRecommender):
|
|||
doc_scores = self.bm25.get_scores(query_tokens)
|
||||
top_indexes = np.argsort(doc_scores)[::-1][:topk]
|
||||
recalled_tools = [list(self.tools.values())[index] for index in top_indexes]
|
||||
print([doc_scores[index] for index in top_indexes])
|
||||
print([recalled_tools[i].name for i in range(len(recalled_tools))])
|
||||
print([recalled_tools[i].schemas["description"] for i in range(len(recalled_tools))])
|
||||
|
||||
logger.info(
|
||||
f"Recalled tools: \n{[tool.name for tool in recalled_tools]}; Scores: {[doc_scores[index] for index in top_indexes]}"
|
||||
)
|
||||
|
||||
return recalled_tools
|
||||
|
||||
|
|
|
|||
|
|
@ -134,7 +134,7 @@ def validate_tool_names(tools: Union[list[str], str]) -> str:
|
|||
# one can define either tool names or tool type names, take union to get the whole set
|
||||
if TOOL_REGISTRY.has_tool(key):
|
||||
valid_tools.update({key: TOOL_REGISTRY.get_tool(key)})
|
||||
elif TOOL_REGISTRY.tool_tool_tag(key):
|
||||
elif TOOL_REGISTRY.has_tool_tag(key):
|
||||
valid_tools.update(TOOL_REGISTRY.get_tools_by_tag(key))
|
||||
else:
|
||||
logger.warning(f"invalid tool name or tool type name: {key}, skipped")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue