add tool registry

This commit is contained in:
yzlin 2024-01-13 01:28:49 +08:00
parent 224bf820b2
commit 46cd219e81
25 changed files with 1582 additions and 59 deletions

View file

@ -8,11 +8,9 @@ import re
from pathlib import Path
from typing import Dict, List, Tuple, Union
import yaml
from tenacity import retry, stop_after_attempt, wait_fixed
from metagpt.actions import Action
from metagpt.const import TOOL_SCHEMA_PATH
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.prompts.ml_engineer import (
@ -24,12 +22,9 @@ from metagpt.prompts.ml_engineer import (
TOOL_USAGE_PROMPT,
)
from metagpt.schema import Message, Plan
from metagpt.tools import TOOL_TYPE_MAPPINGS
from metagpt.tools.tool_registry import TOOL_REGISTRY
from metagpt.utils.common import create_func_config, remove_comments
# Lookup tables derived from TOOL_TYPE_MAPPINGS, keyed by the same keys as that mapping:
# TOOL_TYPE_MODULE holds each entry's `module` attribute and
# TOOL_TYPE_USAGE_PROMPT holds each entry's `usage_prompt` attribute.
TOOL_TYPE_MODULE = {k: v.module for k, v in TOOL_TYPE_MAPPINGS.items()}
TOOL_TYPE_USAGE_PROMPT = {k: v.usage_prompt for k, v in TOOL_TYPE_MAPPINGS.items()}
class BaseWriteAnalysisCode(Action):
DEFAULT_SYSTEM_MSG: str = """You are Code Interpreter, a world-class programmer that can complete any goal by executing code. Strictly follow the plan and generate code step by step. Each step of the code will be executed on the user's machine, and the user will provide the code execution results to you.**Notice: The code for the next step depends on the code for the previous step. Must reuse variables in the lastest other code directly, dont creat it again, it is very import for you. Use !pip install in a standalone block to install missing packages.Usually the libraries you need are already installed.Dont check if packages already imported.**""" # prompt reference: https://github.com/KillianLucas/open-interpreter/blob/v0.1.4/interpreter/system_message.txt
@ -95,49 +90,27 @@ class WriteCodeByGenerate(BaseWriteAnalysisCode):
class WriteCodeWithTools(BaseWriteAnalysisCode):
"""Write code with help of local available tools. Choose tools first, then generate code to use the tools"""
schema_path: Union[Path, str] = TOOL_SCHEMA_PATH
available_tools: dict = {}
# Initialize the action and populate the tool catalog.
# All keyword arguments are forwarded unchanged to the parent initializer;
# afterwards the schemas referenced by self.schema_path are loaded via _load_tools.
def __init__(self, **kwargs):
super().__init__(**kwargs)
# Load tool schemas eagerly at construction time.
self._load_tools(self.schema_path)
def _load_tools(self, schema_path, schema_module=None):
"""Load tools from yaml file"""
if isinstance(schema_path, dict):
schema_module = schema_module or "udf"
self.available_tools.update({schema_module: schema_path})
else:
if isinstance(schema_path, list):
yml_files = schema_path
elif isinstance(schema_path, Path) and schema_path.is_file():
yml_files = [schema_path]
else:
yml_files = schema_path.glob("*.yml")
for yml_file in yml_files:
module = yml_file.stem
with open(yml_file, "r", encoding="utf-8") as f:
self.available_tools[module] = yaml.safe_load(f)
# NOTE(review): this span appears to contain two revisions of the same method
# interleaved (two signatures, two filter loops, two tool_catalog assignments).
# One variant looks tools up in self.available_tools[module]; the other queries
# TOOL_REGISTRY. Confirm which lines are current before relying on either.
def _parse_recommend_tools(self, module: str, recommend_tools: list) -> dict:
def _parse_recommend_tools(self, recommend_tools: list) -> dict:
"""
Parses and validates a list of recommended tools, and retrieves their schema from registry.
Args:
module (str): The module name for querying tools in the registry.
recommend_tools (list): A list of recommended tools.
Returns:
dict: A dict of valid tool schemas.
"""
valid_tools = []
# Module-keyed variant: keep only recommended names that exist under
# self.available_tools[module].
available_tools = self.available_tools[module].keys()
for tool in recommend_tools:
if tool in available_tools:
valid_tools.append(tool)
# Registry variant: keep only names TOOL_REGISTRY reports via has_tool, and
# collect the objects returned by get_tool rather than the bare names.
for tool_name in recommend_tools:
if TOOL_REGISTRY.has_tool(tool_name):
valid_tools.append(TOOL_REGISTRY.get_tool(tool_name))
# Module-keyed variant: map each valid name to its entry in self.available_tools[module].
tool_catalog = {tool: self.available_tools[module][tool] for tool in valid_tools}
# Registry variant: map each collected tool's `name` to its `schema`.
tool_catalog = {tool.name: tool.schema for tool in valid_tools}
return tool_catalog
async def _tool_recommendation(
@ -176,8 +149,10 @@ class WriteCodeWithTools(BaseWriteAnalysisCode):
tool_type = (
plan.current_task.task_type
) # find tool type from task type through exact match, can extend to retrieval in the future
available_tools = self.available_tools.get(tool_type, {})
special_prompt = TOOL_TYPE_USAGE_PROMPT.get(tool_type, "")
available_tools = TOOL_REGISTRY.get_tools_by_type(tool_type)
special_prompt = (
TOOL_REGISTRY.get_tool_type(tool_type).usage_prompt if TOOL_REGISTRY.has_tool_type(tool_type) else ""
)
code_steps = plan.current_task.code_steps
finished_tasks = plan.get_finished_tasks()
@ -185,22 +160,17 @@ class WriteCodeWithTools(BaseWriteAnalysisCode):
code_context = "\n\n".join(code_context)
tool_catalog = {}
module_name = ""
if len(available_tools) > 0:
available_tools = {k: v["description"] for k, v in available_tools.items()}
if available_tools:
available_tools = {tool_name: tool.schema["description"] for tool_name, tool in available_tools.items()}
recommend_tools = await self._tool_recommendation(
plan.current_task.instruction, code_steps, available_tools
)
tool_catalog = self._parse_recommend_tools(tool_type, recommend_tools)
tool_catalog = self._parse_recommend_tools(recommend_tools)
logger.info(f"Recommended tools: \n{recommend_tools}")
module_name = TOOL_TYPE_MODULE[tool_type]
tools_instruction = TOOL_USAGE_PROMPT.format(
special_prompt=special_prompt, module_name=module_name, tool_catalog=tool_catalog
)
tools_instruction = TOOL_USAGE_PROMPT.format(special_prompt=special_prompt, tool_catalog=tool_catalog)
context.append(Message(content=tools_instruction, role="user"))

View file

@ -12,7 +12,7 @@ from metagpt.actions import Action
from metagpt.logs import logger
from metagpt.prompts.ml_engineer import ASSIGN_TASK_TYPE_CONFIG, ASSIGN_TASK_TYPE_PROMPT
from metagpt.schema import Message, Plan, Task
from metagpt.tools import TOOL_TYPE_MAPPINGS
from metagpt.tools import TOOL_REGISTRY
from metagpt.utils.common import CodeParser, create_func_config
@ -47,13 +47,16 @@ class WritePlan(Action):
List[Dict]: tasks with task type assigned
"""
task_list = "\n".join([f"Task {task['task_id']}: {task['instruction']}" for task in tasks])
task_type_desc = "\n".join([f"- **{item.name}**: {item.desc}" for item in TOOL_TYPE_MAPPINGS.values()])
task_type_desc = "\n".join(
[f"- **{tool_type.name}**: {tool_type.desc}" for tool_type in TOOL_REGISTRY.get_tool_types().values()]
) # task type are binded with tool type now, should be improved in the future
prompt = ASSIGN_TASK_TYPE_PROMPT.format(
task_list=task_list, task_type_desc=task_type_desc
) # task types are set to be the same as tool types, for now
tool_config = create_func_config(ASSIGN_TASK_TYPE_CONFIG)
rsp = await self.llm.aask_code(prompt, **tool_config)
task_type_list = rsp["task_type"]
print(f"assigned task types: {task_type_list}")
for task, task_type in zip(tasks, task_type_list):
task["task_type"] = task_type
return json.dumps(tasks)