From 540542834ebafb0043503a7860e5b382d46b47cf Mon Sep 17 00:00:00 2001 From: yzlin Date: Sat, 20 Jan 2024 21:06:48 +0800 Subject: [PATCH] allow select tool at role initialization & restructure writecodewithtools --- metagpt/actions/write_analysis_code.py | 137 +++++++++++--------- metagpt/prompts/ml_engineer.py | 10 +- metagpt/roles/code_interpreter.py | 14 +- metagpt/roles/ml_engineer.py | 2 +- metagpt/roles/role.py | 2 +- metagpt/tools/tool_registry.py | 37 ++++-- tests/metagpt/roles/run_code_interpreter.py | 13 +- tests/metagpt/tools/test_tool_registry.py | 2 +- 8 files changed, 127 insertions(+), 90 deletions(-) diff --git a/metagpt/actions/write_analysis_code.py b/metagpt/actions/write_analysis_code.py index cf806a986..c6e504b9e 100644 --- a/metagpt/actions/write_analysis_code.py +++ b/metagpt/actions/write_analysis_code.py @@ -22,7 +22,8 @@ from metagpt.prompts.ml_engineer import ( TOOL_USAGE_PROMPT, ) from metagpt.schema import Message, Plan -from metagpt.tools.tool_registry import TOOL_REGISTRY +from metagpt.tools import TOOL_REGISTRY +from metagpt.tools.tool_registry import validate_tool_names from metagpt.utils.common import create_func_config, remove_comments @@ -90,30 +91,29 @@ class WriteCodeByGenerate(BaseWriteAnalysisCode): class WriteCodeWithTools(BaseWriteAnalysisCode): """Write code with help of local available tools. Choose tools first, then generate code to use the tools""" - available_tools: dict = {} + # selected tools to choose from, listed by their names. An empty list means selection from all tools. + selected_tools: list[str] = [] - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def _parse_recommend_tools(self, recommend_tools: list) -> dict: + def _get_tools_by_type(self, tool_type: str) -> dict: """ - Parses and validates a list of recommended tools, and retrieves their schema from registry. 
+ Retrieve tools by tool type from registry, but filtered by pre-selected tool list Args: - recommend_tools (list): A list of recommended tools. + tool_type (str): Tool type to retrieve from the registry Returns: - dict: A dict of valid tool schemas. + dict: A dict of tool name to Tool object, representing available tools under the type """ - valid_tools = [] - for tool_name in recommend_tools: - if TOOL_REGISTRY.has_tool(tool_name): - valid_tools.append(TOOL_REGISTRY.get_tool(tool_name)) + candidate_tools = TOOL_REGISTRY.get_tools_by_type(tool_type) + if self.selected_tools: + candidate_tools = { + tool_name: candidate_tools[tool_name] + for tool_name in self.selected_tools + if tool_name in candidate_tools + } + return candidate_tools - tool_catalog = {tool.name: tool.schemas for tool in valid_tools} - return tool_catalog - - async def _tool_recommendation( + async def _recommend_tool( self, task: str, code_steps: str, @@ -128,7 +128,7 @@ class WriteCodeWithTools(BaseWriteAnalysisCode): available_tools (dict): the available tools description Returns: - list: recommended tools for the specified task + dict: schemas of recommended tools for the specified task """ prompt = TOOL_RECOMMENDATION_PROMPT.format( current_task=task, @@ -138,42 +138,62 @@ class WriteCodeWithTools(BaseWriteAnalysisCode): tool_config = create_func_config(SELECT_FUNCTION_TOOLS) rsp = await self.llm.aask_code(prompt, **tool_config) recommend_tools = rsp["recommend_tools"] - return recommend_tools + logger.info(f"Recommended tools: \n{recommend_tools}") + + # Parses and validates the recommended tools, for LLM might hallucinate and recommend non-existing tools + valid_tools = validate_tool_names(recommend_tools, return_tool_object=True) + + tool_schemas = {tool.name: tool.schemas for tool in valid_tools} + + return tool_schemas + + async def _prepare_tools(self, plan: Plan) -> Tuple[dict, str]: + """Prepare tool schemas and usage instructions according to current task + + Args: + plan (Plan): 
The overall plan containing task information. + + Returns: + Tuple[dict, str]: A tool schemas ({tool_name: tool_schema_dict}) and a usage prompt for the type of tools selected + """ + # find tool type from task type through exact match, can extend to retrieval in the future + tool_type = plan.current_task.task_type + + # prepare tool-type-specific instruction + tool_type_usage_prompt = ( + TOOL_REGISTRY.get_tool_type(tool_type).usage_prompt if TOOL_REGISTRY.has_tool_type(tool_type) else "" + ) + + # prepare schemas of available tools + tool_schemas = {} + available_tools = self._get_tools_by_type(tool_type) + if available_tools: + available_tools = {tool_name: tool.schemas["description"] for tool_name, tool in available_tools.items()} + code_steps = plan.current_task.code_steps + tool_schemas = await self._recommend_tool(plan.current_task.instruction, code_steps, available_tools) + + return tool_schemas, tool_type_usage_prompt async def run( self, context: List[Message], - plan: Plan = None, + plan: Plan, **kwargs, ) -> str: - tool_type = ( - plan.current_task.task_type - ) # find tool type from task type through exact match, can extend to retrieval in the future - available_tools = TOOL_REGISTRY.get_tools_by_type(tool_type) - special_prompt = ( - TOOL_REGISTRY.get_tool_type(tool_type).usage_prompt if TOOL_REGISTRY.has_tool_type(tool_type) else "" + # prepare tool schemas and tool-type-specific instruction + tool_schemas, tool_type_usage_prompt = await self._prepare_tools(plan=plan) + + # form a complete tool usage instruction and include it as a message in context + tools_instruction = TOOL_USAGE_PROMPT.format( + tool_schemas=tool_schemas, tool_type_usage_prompt=tool_type_usage_prompt ) - code_steps = plan.current_task.code_steps - - tool_catalog = {} - - if available_tools: - available_tools = {tool_name: tool.schemas["description"] for tool_name, tool in available_tools.items()} - - recommend_tools = await self._tool_recommendation( - 
plan.current_task.instruction, code_steps, available_tools - ) - tool_catalog = self._parse_recommend_tools(recommend_tools) - logger.info(f"Recommended tools: \n{recommend_tools}") - - tools_instruction = TOOL_USAGE_PROMPT.format(special_prompt=special_prompt, tool_catalog=tool_catalog) - context.append(Message(content=tools_instruction, role="user")) + # prepare prompt & LLM call prompt = self.process_msg(context) - tool_config = create_func_config(CODE_GENERATOR_WITH_TOOLS) rsp = await self.llm.aask_code(prompt, **tool_config) + return rsp @@ -185,36 +205,25 @@ class WriteCodeWithToolsML(WriteCodeWithTools): column_info: str = "", **kwargs, ) -> Tuple[List[Message], str]: - tool_type = ( - plan.current_task.task_type - ) # find tool type from task type through exact match, can extend to retrieval in the future - available_tools = TOOL_REGISTRY.get_tools_by_type(tool_type) - special_prompt = ( - TOOL_REGISTRY.get_tool_type(tool_type).usage_prompt if TOOL_REGISTRY.has_tool_type(tool_type) else "" - ) - code_steps = plan.current_task.code_steps + # prepare tool schemas and tool-type-specific instruction + tool_schemas, tool_type_usage_prompt = await self._prepare_tools(plan=plan) + # ML-specific variables to be used in prompt + code_steps = plan.current_task.code_steps finished_tasks = plan.get_finished_tasks() code_context = [remove_comments(task.code) for task in finished_tasks] code_context = "\n\n".join(code_context) - if available_tools: - available_tools = {tool_name: tool.schemas["description"] for tool_name, tool in available_tools.items()} - - recommend_tools = await self._tool_recommendation( - plan.current_task.instruction, code_steps, available_tools - ) - tool_catalog = self._parse_recommend_tools(recommend_tools) - logger.info(f"Recommended tools: \n{recommend_tools}") - + # prepare prompt depending on tool availability & LLM call + if tool_schemas: prompt = ML_TOOL_USAGE_PROMPT.format( user_requirement=plan.goal, history_code=code_context, 
current_task=plan.current_task.instruction, column_info=column_info, - special_prompt=special_prompt, + tool_type_usage_prompt=tool_type_usage_prompt, code_steps=code_steps, - tool_catalog=tool_catalog, + tool_schemas=tool_schemas, ) else: @@ -223,13 +232,15 @@ class WriteCodeWithToolsML(WriteCodeWithTools): history_code=code_context, current_task=plan.current_task.instruction, column_info=column_info, - special_prompt=special_prompt, + tool_type_usage_prompt=tool_type_usage_prompt, code_steps=code_steps, ) - tool_config = create_func_config(CODE_GENERATOR_WITH_TOOLS) rsp = await self.llm.aask_code(prompt, **tool_config) + + # Extra output to be used for potential debugging context = [Message(content=prompt, role="user")] + return context, rsp diff --git a/metagpt/prompts/ml_engineer.py b/metagpt/prompts/ml_engineer.py index 3fd895e6e..ac95e14bd 100644 --- a/metagpt/prompts/ml_engineer.py +++ b/metagpt/prompts/ml_engineer.py @@ -161,7 +161,7 @@ Latest data info after previous tasks: # Task Write complete code for 'Current Task'. And avoid duplicating code from 'Done Tasks', such as repeated import of packages, reading data, etc. -Specifically, {special_prompt} +Specifically, {tool_type_usage_prompt} # Code Steps: Strictly follow steps below when you writing code if it's convenient. @@ -192,7 +192,7 @@ model.fit(train, y_train) TOOL_USAGE_PROMPT = """ # Instruction Write complete code for 'Current Task'. And avoid duplicating code from finished tasks, such as repeated import of packages, reading data, etc. -Specifically, {special_prompt} +Specifically, {tool_type_usage_prompt} # Capabilities - You can utilize pre-defined tools in any code lines from 'Available Tools' in the form of Python Class. @@ -200,7 +200,7 @@ Specifically, {special_prompt} # Available Tools (can be empty): Each Class tool is described in JSON format. When you call a tool, import the tool first. 
-{tool_catalog} +{tool_schemas} # Constraints: - Ensure the output new code is executable in the same Jupyter notebook with previous tasks code have been executed. @@ -225,7 +225,7 @@ Latest data info after previous tasks: # Task Write complete code for 'Current Task'. And avoid duplicating code from 'Done Tasks', such as repeated import of packages, reading data, etc. -Specifically, {special_prompt} +Specifically, {tool_type_usage_prompt} # Code Steps: Strictly follow steps below when you writing code if it's convenient. @@ -237,7 +237,7 @@ Strictly follow steps below when you writing code if it's convenient. # Available Tools: Each Class tool is described in JSON format. When you call a tool, import the tool from its path first. -{tool_catalog} +{tool_schemas} # Output Example: when current task is "do data preprocess, like fill missing value, handle outliers, etc.", and their are two steps in 'Code Steps', the code be like: diff --git a/metagpt/roles/code_interpreter.py b/metagpt/roles/code_interpreter.py index f972e72e2..11ede6068 100644 --- a/metagpt/roles/code_interpreter.py +++ b/metagpt/roles/code_interpreter.py @@ -19,6 +19,7 @@ class CodeInterpreter(Role): make_udfs: bool = False # whether to save user-defined functions use_code_steps: bool = False execute_code: ExecutePyCode = Field(default_factory=ExecutePyCode, exclude=True) + tools: list[str] = [] def __init__( self, @@ -27,13 +28,20 @@ class CodeInterpreter(Role): goal="", auto_run=True, use_tools=False, - make_udfs=False, + tools=[], **kwargs, ): super().__init__( - name=name, profile=profile, goal=goal, auto_run=auto_run, use_tools=use_tools, make_udfs=make_udfs, **kwargs + name=name, profile=profile, goal=goal, auto_run=auto_run, use_tools=use_tools, tools=tools, **kwargs ) self._set_react_mode(react_mode="plan_and_act", auto_run=auto_run, use_tools=use_tools) + if use_tools and tools: + from metagpt.tools.tool_registry import ( + validate_tool_names, # import upon use + ) + + self.tools = 
validate_tool_names(tools) + logger.info(f"will only use {self.tools} as tools") @property def working_memory(self): @@ -92,7 +100,7 @@ class CodeInterpreter(Role): return code["code"], result, success async def _write_code(self): - todo = WriteCodeByGenerate() if not self.use_tools else WriteCodeWithTools() + todo = WriteCodeByGenerate() if not self.use_tools else WriteCodeWithTools(selected_tools=self.tools) logger.info(f"ready to {todo.name}") context = self.planner.get_useful_memories() diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index 6b671f9c2..d1a22b9d3 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -27,7 +27,7 @@ class MLEngineer(CodeInterpreter): column_info = await self._update_data_columns() logger.info("Write code with tools") - tool_context, code = await WriteCodeWithToolsML().run( + tool_context, code = await WriteCodeWithToolsML(selected_tools=self.tools).run( context=[], # context assembled inside the Action plan=self.planner.plan, column_info=column_info, diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index a2f2f2e9d..21e48a127 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -477,7 +477,7 @@ class Role(SerializationMixin, is_polymorphic_base=True): else: # update plan according to user's feedback and to take on changed tasks - await self.planner.update_plan(review) + await self.planner.update_plan() completed_plan_memory = self.planner.get_useful_memories() # completed plan as a outcome diff --git a/metagpt/tools/tool_registry.py b/metagpt/tools/tool_registry.py index fbdfb3cfd..c064a19de 100644 --- a/metagpt/tools/tool_registry.py +++ b/metagpt/tools/tool_registry.py @@ -11,6 +11,7 @@ import re from collections import defaultdict import yaml +from pydantic import BaseModel from metagpt.const import TOOL_SCHEMA_PATH from metagpt.logs import logger @@ -18,11 +19,10 @@ from metagpt.tools.tool_convert import convert_code_to_tool_schema from 
metagpt.tools.tool_data_type import Tool, ToolSchema, ToolType -class ToolRegistry: - def __init__(self): - self.tools = {} - self.tool_types = {} - self.tools_by_types = defaultdict(dict) # two-layer k-v, {tool_type: {tool_name: {...}, ...}, ...} +class ToolRegistry(BaseModel): + tools: dict = {} + tool_types: dict = {} + tools_by_types: dict = defaultdict(dict) # two-layer k-v, {tool_type: {tool_name: {...}, ...}, ...} def register_tool_type(self, tool_type: ToolType): self.tool_types[tool_type.name] = tool_type @@ -70,22 +70,22 @@ class ToolRegistry: self.tools_by_types[tool_type][tool_name] = tool logger.info(f"{tool_name} registered") - def has_tool(self, key): + def has_tool(self, key: str) -> Tool: return key in self.tools - def get_tool(self, key): + def get_tool(self, key) -> Tool: return self.tools.get(key) - def get_tools_by_type(self, key): - return self.tools_by_types.get(key) + def get_tools_by_type(self, key) -> dict[str, Tool]: + return self.tools_by_types.get(key, {}) - def has_tool_type(self, key): + def has_tool_type(self, key) -> bool: return key in self.tool_types - def get_tool_type(self, key): + def get_tool_type(self, key) -> ToolType: return self.tool_types.get(key) - def get_tool_types(self): + def get_tool_types(self) -> dict[str, ToolType]: return self.tool_types @@ -141,3 +141,16 @@ def make_schema(tool_source_object, include, path): print(e) return schema + + +def validate_tool_names(tools: list[str], return_tool_object=False) -> list[str]: + valid_tools = [] + for tool_name in tools: + if not TOOL_REGISTRY.has_tool(tool_name): + logger.warning( + f"Specified tool {tool_name} not found and was skipped. 
Check if you have registered it properly" + ) + else: + valid_tool = TOOL_REGISTRY.get_tool(tool_name) if return_tool_object else tool_name + valid_tools.append(valid_tool) + return valid_tools diff --git a/tests/metagpt/roles/run_code_interpreter.py b/tests/metagpt/roles/run_code_interpreter.py index 539b20286..766a25998 100644 --- a/tests/metagpt/roles/run_code_interpreter.py +++ b/tests/metagpt/roles/run_code_interpreter.py @@ -10,7 +10,7 @@ from metagpt.utils.recovery_util import load_history, save_history async def run_code_interpreter( - role_class, requirement, auto_run, use_tools, use_code_steps, make_udfs, use_udfs, save_dir + role_class, requirement, auto_run, use_tools, use_code_steps, make_udfs, use_udfs, save_dir, tools ): """ The main function to run the MLEngineer with optional history loading. @@ -25,7 +25,9 @@ async def run_code_interpreter( """ if role_class == "ci": - role = CodeInterpreter(goal=requirement, auto_run=auto_run, use_tools=use_tools, make_udfs=make_udfs) + role = CodeInterpreter( + goal=requirement, auto_run=auto_run, use_tools=use_tools, make_udfs=make_udfs, tools=tools + ) else: role = MLEngineer( goal=requirement, @@ -33,7 +35,7 @@ async def run_code_interpreter( use_tools=use_tools, use_code_steps=use_code_steps, make_udfs=make_udfs, - use_udfs=use_udfs, + tools=tools, ) if save_dir: @@ -73,6 +75,8 @@ if __name__ == "__main__": use_tools = True make_udfs = False use_udfs = False + tools = [] + # tools = ["FillMissingValue", "CatCross", "non_existing_test"] async def main( role_class: str = role_class, @@ -83,9 +87,10 @@ if __name__ == "__main__": make_udfs: bool = make_udfs, use_udfs: bool = use_udfs, save_dir: str = save_dir, + tools=tools, ): await run_code_interpreter( - role_class, requirement, auto_run, use_tools, use_code_steps, make_udfs, use_udfs, save_dir + role_class, requirement, auto_run, use_tools, use_code_steps, make_udfs, use_udfs, save_dir, tools ) fire.Fire(main) diff --git 
a/tests/metagpt/tools/test_tool_registry.py b/tests/metagpt/tools/test_tool_registry.py index 582c368a8..c24122e39 100644 --- a/tests/metagpt/tools/test_tool_registry.py +++ b/tests/metagpt/tools/test_tool_registry.py @@ -98,4 +98,4 @@ def test_get_tools_by_type(tool_registry, schema_yaml): # Test case for when the tool type does not exist def test_get_tools_by_nonexistent_type(tool_registry): tools_by_type = tool_registry.get_tools_by_type("NonexistentType") - assert tools_by_type is None + assert not tools_by_type