Allow selecting tools at role initialization & restructure WriteCodeWithTools

This commit is contained in:
yzlin 2024-01-20 21:06:48 +08:00
parent 2ccfe31123
commit 540542834e
8 changed files with 127 additions and 90 deletions

View file

@ -22,7 +22,8 @@ from metagpt.prompts.ml_engineer import (
TOOL_USAGE_PROMPT,
)
from metagpt.schema import Message, Plan
from metagpt.tools.tool_registry import TOOL_REGISTRY
from metagpt.tools import TOOL_REGISTRY
from metagpt.tools.tool_registry import validate_tool_names
from metagpt.utils.common import create_func_config, remove_comments
@ -90,30 +91,29 @@ class WriteCodeByGenerate(BaseWriteAnalysisCode):
class WriteCodeWithTools(BaseWriteAnalysisCode):
"""Write code with help of local available tools. Choose tools first, then generate code to use the tools"""
available_tools: dict = {}
# selected tools to choose from, listed by their names. An empty list means selection from all tools.
selected_tools: list[str] = []
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _parse_recommend_tools(self, recommend_tools: list) -> dict:
def _get_tools_by_type(self, tool_type: str) -> dict:
"""
Parses and validates a list of recommended tools, and retrieves their schema from registry.
Retrieve tools by tool type from registry, but filtered by pre-selected tool list
Args:
recommend_tools (list): A list of recommended tools.
tool_type (str): Tool type to retrieve from the registry
Returns:
dict: A dict of valid tool schemas.
dict: A dict of tool name to Tool object, representing available tools under the type
"""
valid_tools = []
for tool_name in recommend_tools:
if TOOL_REGISTRY.has_tool(tool_name):
valid_tools.append(TOOL_REGISTRY.get_tool(tool_name))
candidate_tools = TOOL_REGISTRY.get_tools_by_type(tool_type)
if self.selected_tools:
candidate_tools = {
tool_name: candidate_tools[tool_name]
for tool_name in self.selected_tools
if tool_name in candidate_tools
}
return candidate_tools
tool_catalog = {tool.name: tool.schemas for tool in valid_tools}
return tool_catalog
async def _tool_recommendation(
async def _recommend_tool(
self,
task: str,
code_steps: str,
@ -128,7 +128,7 @@ class WriteCodeWithTools(BaseWriteAnalysisCode):
available_tools (dict): the available tools description
Returns:
list: recommended tools for the specified task
dict: schemas of recommended tools for the specified task
"""
prompt = TOOL_RECOMMENDATION_PROMPT.format(
current_task=task,
@ -138,42 +138,62 @@ class WriteCodeWithTools(BaseWriteAnalysisCode):
tool_config = create_func_config(SELECT_FUNCTION_TOOLS)
rsp = await self.llm.aask_code(prompt, **tool_config)
recommend_tools = rsp["recommend_tools"]
return recommend_tools
logger.info(f"Recommended tools: \n{recommend_tools}")
# Parses and validates the recommended tools, for LLM might hallucinate and recommend non-existing tools
valid_tools = validate_tool_names(recommend_tools, return_tool_object=True)
tool_schemas = {tool.name: tool.schemas for tool in valid_tools}
return tool_schemas
async def _prepare_tools(self, plan: Plan) -> Tuple[dict, str]:
"""Prepare tool schemas and usage instructions according to current task
Args:
plan (Plan): The overall plan containing task information.
Returns:
Tuple[dict, str]: A tool schemas ({tool_name: tool_schema_dict}) and a usage prompt for the type of tools selected
"""
# find tool type from task type through exact match, can extend to retrieval in the future
tool_type = plan.current_task.task_type
# prepare tool-type-specific instruction
tool_type_usage_prompt = (
TOOL_REGISTRY.get_tool_type(tool_type).usage_prompt if TOOL_REGISTRY.has_tool_type(tool_type) else ""
)
# prepare schemas of available tools
tool_schemas = {}
available_tools = self._get_tools_by_type(tool_type)
if available_tools:
available_tools = {tool_name: tool.schemas["description"] for tool_name, tool in available_tools.items()}
code_steps = plan.current_task.code_steps
tool_schemas = await self._recommend_tool(plan.current_task.instruction, code_steps, available_tools)
return tool_schemas, tool_type_usage_prompt
async def run(
self,
context: List[Message],
plan: Plan = None,
plan: Plan,
**kwargs,
) -> str:
tool_type = (
plan.current_task.task_type
) # find tool type from task type through exact match, can extend to retrieval in the future
available_tools = TOOL_REGISTRY.get_tools_by_type(tool_type)
special_prompt = (
TOOL_REGISTRY.get_tool_type(tool_type).usage_prompt if TOOL_REGISTRY.has_tool_type(tool_type) else ""
# prepare tool schemas and tool-type-specific instruction
tool_schemas, tool_type_usage_prompt = await self._prepare_tools(plan=plan)
# form a complete tool usage instruction and include it as a message in context
tools_instruction = TOOL_USAGE_PROMPT.format(
tool_schemas=tool_schemas, tool_type_usage_prompt=tool_type_usage_prompt
)
code_steps = plan.current_task.code_steps
tool_catalog = {}
if available_tools:
available_tools = {tool_name: tool.schemas["description"] for tool_name, tool in available_tools.items()}
recommend_tools = await self._tool_recommendation(
plan.current_task.instruction, code_steps, available_tools
)
tool_catalog = self._parse_recommend_tools(recommend_tools)
logger.info(f"Recommended tools: \n{recommend_tools}")
tools_instruction = TOOL_USAGE_PROMPT.format(special_prompt=special_prompt, tool_catalog=tool_catalog)
context.append(Message(content=tools_instruction, role="user"))
# prepare prompt & LLM call
prompt = self.process_msg(context)
tool_config = create_func_config(CODE_GENERATOR_WITH_TOOLS)
rsp = await self.llm.aask_code(prompt, **tool_config)
return rsp
@ -185,36 +205,25 @@ class WriteCodeWithToolsML(WriteCodeWithTools):
column_info: str = "",
**kwargs,
) -> Tuple[List[Message], str]:
tool_type = (
plan.current_task.task_type
) # find tool type from task type through exact match, can extend to retrieval in the future
available_tools = TOOL_REGISTRY.get_tools_by_type(tool_type)
special_prompt = (
TOOL_REGISTRY.get_tool_type(tool_type).usage_prompt if TOOL_REGISTRY.has_tool_type(tool_type) else ""
)
code_steps = plan.current_task.code_steps
# prepare tool schemas and tool-type-specific instruction
tool_schemas, tool_type_usage_prompt = await self._prepare_tools(plan=plan)
# ML-specific variables to be used in prompt
code_steps = plan.current_task.code_steps
finished_tasks = plan.get_finished_tasks()
code_context = [remove_comments(task.code) for task in finished_tasks]
code_context = "\n\n".join(code_context)
if available_tools:
available_tools = {tool_name: tool.schemas["description"] for tool_name, tool in available_tools.items()}
recommend_tools = await self._tool_recommendation(
plan.current_task.instruction, code_steps, available_tools
)
tool_catalog = self._parse_recommend_tools(recommend_tools)
logger.info(f"Recommended tools: \n{recommend_tools}")
# prepare prompt depending on tool availability & LLM call
if tool_schemas:
prompt = ML_TOOL_USAGE_PROMPT.format(
user_requirement=plan.goal,
history_code=code_context,
current_task=plan.current_task.instruction,
column_info=column_info,
special_prompt=special_prompt,
tool_type_usage_prompt=tool_type_usage_prompt,
code_steps=code_steps,
tool_catalog=tool_catalog,
tool_schemas=tool_schemas,
)
else:
@ -223,13 +232,15 @@ class WriteCodeWithToolsML(WriteCodeWithTools):
history_code=code_context,
current_task=plan.current_task.instruction,
column_info=column_info,
special_prompt=special_prompt,
tool_type_usage_prompt=tool_type_usage_prompt,
code_steps=code_steps,
)
tool_config = create_func_config(CODE_GENERATOR_WITH_TOOLS)
rsp = await self.llm.aask_code(prompt, **tool_config)
# Extra output to be used for potential debugging
context = [Message(content=prompt, role="user")]
return context, rsp

View file

@ -161,7 +161,7 @@ Latest data info after previous tasks:
# Task
Write complete code for 'Current Task'. And avoid duplicating code from 'Done Tasks', such as repeated import of packages, reading data, etc.
Specifically, {special_prompt}
Specifically, {tool_type_usage_prompt}
# Code Steps:
Strictly follow steps below when you writing code if it's convenient.
@ -192,7 +192,7 @@ model.fit(train, y_train)
TOOL_USAGE_PROMPT = """
# Instruction
Write complete code for 'Current Task'. And avoid duplicating code from finished tasks, such as repeated import of packages, reading data, etc.
Specifically, {special_prompt}
Specifically, {tool_type_usage_prompt}
# Capabilities
- You can utilize pre-defined tools in any code lines from 'Available Tools' in the form of Python Class.
@ -200,7 +200,7 @@ Specifically, {special_prompt}
# Available Tools (can be empty):
Each Class tool is described in JSON format. When you call a tool, import the tool first.
{tool_catalog}
{tool_schemas}
# Constraints:
- Ensure the output new code is executable in the same Jupyter notebook with previous tasks code have been executed.
@ -225,7 +225,7 @@ Latest data info after previous tasks:
# Task
Write complete code for 'Current Task'. And avoid duplicating code from 'Done Tasks', such as repeated import of packages, reading data, etc.
Specifically, {special_prompt}
Specifically, {tool_type_usage_prompt}
# Code Steps:
Strictly follow steps below when you writing code if it's convenient.
@ -237,7 +237,7 @@ Strictly follow steps below when you writing code if it's convenient.
# Available Tools:
Each Class tool is described in JSON format. When you call a tool, import the tool from its path first.
{tool_catalog}
{tool_schemas}
# Output Example:
when current task is "do data preprocess, like fill missing value, handle outliers, etc.", and their are two steps in 'Code Steps', the code be like:

View file

@ -19,6 +19,7 @@ class CodeInterpreter(Role):
make_udfs: bool = False # whether to save user-defined functions
use_code_steps: bool = False
execute_code: ExecutePyCode = Field(default_factory=ExecutePyCode, exclude=True)
tools: list[str] = []
def __init__(
self,
@ -27,13 +28,20 @@ class CodeInterpreter(Role):
goal="",
auto_run=True,
use_tools=False,
make_udfs=False,
tools=[],
**kwargs,
):
super().__init__(
name=name, profile=profile, goal=goal, auto_run=auto_run, use_tools=use_tools, make_udfs=make_udfs, **kwargs
name=name, profile=profile, goal=goal, auto_run=auto_run, use_tools=use_tools, tools=tools, **kwargs
)
self._set_react_mode(react_mode="plan_and_act", auto_run=auto_run, use_tools=use_tools)
if use_tools and tools:
from metagpt.tools.tool_registry import (
validate_tool_names, # import upon use
)
self.tools = validate_tool_names(tools)
logger.info(f"will only use {self.tools} as tools")
@property
def working_memory(self):
@ -92,7 +100,7 @@ class CodeInterpreter(Role):
return code["code"], result, success
async def _write_code(self):
todo = WriteCodeByGenerate() if not self.use_tools else WriteCodeWithTools()
todo = WriteCodeByGenerate() if not self.use_tools else WriteCodeWithTools(selected_tools=self.tools)
logger.info(f"ready to {todo.name}")
context = self.planner.get_useful_memories()

View file

@ -27,7 +27,7 @@ class MLEngineer(CodeInterpreter):
column_info = await self._update_data_columns()
logger.info("Write code with tools")
tool_context, code = await WriteCodeWithToolsML().run(
tool_context, code = await WriteCodeWithToolsML(selected_tools=self.tools).run(
context=[], # context assembled inside the Action
plan=self.planner.plan,
column_info=column_info,

View file

@ -477,7 +477,7 @@ class Role(SerializationMixin, is_polymorphic_base=True):
else:
# update plan according to user's feedback and to take on changed tasks
await self.planner.update_plan(review)
await self.planner.update_plan()
completed_plan_memory = self.planner.get_useful_memories() # completed plan as a outcome

View file

@ -11,6 +11,7 @@ import re
from collections import defaultdict
import yaml
from pydantic import BaseModel
from metagpt.const import TOOL_SCHEMA_PATH
from metagpt.logs import logger
@ -18,11 +19,10 @@ from metagpt.tools.tool_convert import convert_code_to_tool_schema
from metagpt.tools.tool_data_type import Tool, ToolSchema, ToolType
class ToolRegistry:
def __init__(self):
self.tools = {}
self.tool_types = {}
self.tools_by_types = defaultdict(dict) # two-layer k-v, {tool_type: {tool_name: {...}, ...}, ...}
class ToolRegistry(BaseModel):
tools: dict = {}
tool_types: dict = {}
tools_by_types: dict = defaultdict(dict) # two-layer k-v, {tool_type: {tool_name: {...}, ...}, ...}
def register_tool_type(self, tool_type: ToolType):
self.tool_types[tool_type.name] = tool_type
@ -70,22 +70,22 @@ class ToolRegistry:
self.tools_by_types[tool_type][tool_name] = tool
logger.info(f"{tool_name} registered")
def has_tool(self, key):
def has_tool(self, key: str) -> bool:
    """Return True if a tool named `key` is present in `self.tools`, else False."""
    # Membership test only; use get_tool() to obtain the registered object itself.
    return key in self.tools
def get_tool(self, key):
def get_tool(self, key) -> Tool:
return self.tools.get(key)
def get_tools_by_type(self, key):
return self.tools_by_types.get(key)
def get_tools_by_type(self, key) -> dict[str, Tool]:
return self.tools_by_types.get(key, {})
def has_tool_type(self, key):
def has_tool_type(self, key) -> bool:
return key in self.tool_types
def get_tool_type(self, key):
def get_tool_type(self, key) -> ToolType:
return self.tool_types.get(key)
def get_tool_types(self):
def get_tool_types(self) -> dict[str, ToolType]:
return self.tool_types
@ -141,3 +141,16 @@ def make_schema(tool_source_object, include, path):
print(e)
return schema
def validate_tool_names(tools: list[str], return_tool_object: bool = False) -> "list[str] | list[Tool]":
    """Keep only the tool names that are registered in TOOL_REGISTRY.

    Unregistered names are skipped with a warning instead of raising an error.

    Args:
        tools (list[str]): Tool names to check against the registry.
        return_tool_object (bool): If True, return the object obtained from
            TOOL_REGISTRY.get_tool for each valid name instead of the name itself.

    Returns:
        list: The valid tool names, or the corresponding registry objects when
        return_tool_object is True, in the same order as the input.
    """
    valid_tools = []
    for tool_name in tools:
        if not TOOL_REGISTRY.has_tool(tool_name):
            logger.warning(
                f"Specified tool {tool_name} not found and was skipped. Check if you have registered it properly"
            )
            continue
        # The element type depends on the flag: registry object or plain name.
        valid_tools.append(TOOL_REGISTRY.get_tool(tool_name) if return_tool_object else tool_name)
    return valid_tools

View file

@ -10,7 +10,7 @@ from metagpt.utils.recovery_util import load_history, save_history
async def run_code_interpreter(
role_class, requirement, auto_run, use_tools, use_code_steps, make_udfs, use_udfs, save_dir
role_class, requirement, auto_run, use_tools, use_code_steps, make_udfs, use_udfs, save_dir, tools
):
"""
The main function to run the MLEngineer with optional history loading.
@ -25,7 +25,9 @@ async def run_code_interpreter(
"""
if role_class == "ci":
role = CodeInterpreter(goal=requirement, auto_run=auto_run, use_tools=use_tools, make_udfs=make_udfs)
role = CodeInterpreter(
goal=requirement, auto_run=auto_run, use_tools=use_tools, make_udfs=make_udfs, tools=tools
)
else:
role = MLEngineer(
goal=requirement,
@ -33,7 +35,7 @@ async def run_code_interpreter(
use_tools=use_tools,
use_code_steps=use_code_steps,
make_udfs=make_udfs,
use_udfs=use_udfs,
tools=tools,
)
if save_dir:
@ -73,6 +75,8 @@ if __name__ == "__main__":
use_tools = True
make_udfs = False
use_udfs = False
tools = []
# tools = ["FillMissingValue", "CatCross", "non_existing_test"]
async def main(
role_class: str = role_class,
@ -83,9 +87,10 @@ if __name__ == "__main__":
make_udfs: bool = make_udfs,
use_udfs: bool = use_udfs,
save_dir: str = save_dir,
tools=tools,
):
await run_code_interpreter(
role_class, requirement, auto_run, use_tools, use_code_steps, make_udfs, use_udfs, save_dir
role_class, requirement, auto_run, use_tools, use_code_steps, make_udfs, use_udfs, save_dir, tools
)
fire.Fire(main)

View file

@ -98,4 +98,4 @@ def test_get_tools_by_type(tool_registry, schema_yaml):
# Test case for when the tool type does not exist
def test_get_tools_by_nonexistent_type(tool_registry):
tools_by_type = tool_registry.get_tools_by_type("NonexistentType")
assert tools_by_type is None
assert not tools_by_type