From b81fefffa17233ff0654395841e8d5bdd604a225 Mon Sep 17 00:00:00 2001 From: lidanyang Date: Thu, 30 Nov 2023 16:28:02 +0800 Subject: [PATCH] avoid repetitive tool desc between steps --- metagpt/actions/write_analysis_code.py | 22 +++++++++++++++------- metagpt/prompts/ml_engineer.py | 6 +++++- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/metagpt/actions/write_analysis_code.py b/metagpt/actions/write_analysis_code.py index 787fb8d3e..6fff1c66f 100644 --- a/metagpt/actions/write_analysis_code.py +++ b/metagpt/actions/write_analysis_code.py @@ -5,7 +5,7 @@ @File : write_code_v2.py """ import json -from typing import Dict, List, Union +from typing import Dict, List, Union, Tuple from metagpt.actions import Action from metagpt.prompts.ml_engineer import ( @@ -100,24 +100,31 @@ class WriteCodeWithTools(BaseWriteAnalysisCode): """Write code with help of local available tools. Choose tools first, then generate code to use the tools""" @staticmethod - def _parse_recommend_tools(module: str, recommend_tools: list) -> str: + def _parse_recommend_tools(module: str, recommend_tools: list) -> Tuple[Dict, List[Dict]]: """ - Converts recommended tools to a JSON string and checks tool availability in the registry. + Parses and validates a list of recommended tools, and retrieves their schema from registry. Args: module (str): The module name for querying tools in the registry. recommend_tools (list): A list of lists of recommended tools for each step. Returns: - str: A JSON string with available tools and their schemas for each step. + Tuple[Dict, List[Dict]]: + - valid_tools: A dict of lists of valid tools for each step. + - tool_catalog: A list of dicts of unique tool schemas. """ valid_tools = {} available_tools = registry.get_all_by_module(module).keys() for index, tools in enumerate(recommend_tools): key = f"Step {index + 1}" tools = [tool for tool in tools if tool in available_tools] - valid_tools[key] = registry.get_schemas(module, tools) - return json.dumps(valid_tools) + valid_tools[key] = tools + + unique_tools = set() + for tools in valid_tools.values(): + unique_tools.update(tools) + tool_catalog = registry.get_schemas(module, unique_tools) + return valid_tools, tool_catalog async def _tool_recommendation( self, task: str, data_desc: str, code_steps: str, available_tools: list @@ -166,7 +173,7 @@ class WriteCodeWithTools(BaseWriteAnalysisCode): recommend_tools = await self._tool_recommendation( task, task_guide, available_tools ) - recommend_tools = self._parse_recommend_tools(task_type, recommend_tools) + recommend_tools, tool_catalog = self._parse_recommend_tools(task_type, recommend_tools) special_prompt = ML_SPECIFIC_PROMPT.get(task_type, "") module_name = ML_MODULE_MAP[task_type] @@ -191,6 +198,7 @@ class WriteCodeWithTools(BaseWriteAnalysisCode): module_name=module_name, output_desc=output_desc, available_tools=recommend_tools, + tool_catalog=tool_catalog, ) tool_config = create_func_config(CODE_GENERATOR_WITH_TOOLS) rsp = await self.llm.aask_code(prompt, **tool_config) diff --git a/metagpt/prompts/ml_engineer.py b/metagpt/prompts/ml_engineer.py index 55ac27d82..70a40ef34 100644 --- a/metagpt/prompts/ml_engineer.py +++ b/metagpt/prompts/ml_engineer.py @@ -95,9 +95,13 @@ from metagpt.tools.functions.libs.feature_engineering import fill_missing_value ``` ## Available Functions for Each Step: -Each function is described in JSON format, including the function name and parameters. {output_desc} +Here's a list of all available functions for each step. You can find more details about each function in [## Function Catalog] {available_tools} +## Function Catalog: +Each function is described in JSON format, including the function name and parameters. {output_desc} +{function_catalog} + ## Your Output Format: Generate the complete code for every step, listing any used function tools at the beginning of the step: ```python