Merge branch 'di_mgx' into 'mgx_ops'

DI features: register tools from path

See merge request pub/MetaGPT!8
This commit is contained in:
林义章 2024-03-30 08:55:34 +00:00
commit 7fc5cc3804
7 changed files with 231 additions and 30 deletions

View file

@ -1,7 +1,7 @@
from __future__ import annotations
import json
from typing import Literal, Union
from typing import Literal
from pydantic import Field, model_validator
@ -39,7 +39,7 @@ class DataInterpreter(Role):
use_plan: bool = True
use_reflection: bool = False
execute_code: ExecuteNbCode = Field(default_factory=ExecuteNbCode, exclude=True)
tools: Union[str, list[str]] = [] # Use special symbol ["<all>"] to indicate use of all registered tools
tools: list[str] = [] # Use special symbol ["<all>"] to indicate use of all registered tools
tool_recommender: ToolRecommender = None
react_mode: Literal["plan_and_act", "react"] = "plan_and_act"
max_react_loop: int = 10 # used for react mode
@ -50,7 +50,7 @@ class DataInterpreter(Role):
self.use_plan = (
self.react_mode == "plan_and_act"
) # create a flag for convenience, overwrite any passed-in value
if self.tools:
if self.tools and not self.tool_recommender:
self.tool_recommender = BM25ToolRecommender(tools=self.tools)
self.set_actions([WriteAnalysisCode])
self._set_state(0)
@ -104,7 +104,7 @@ class DataInterpreter(Role):
plan_status = self.planner.get_plan_status() if self.use_plan else ""
# tool info
if self.tools:
if self.tool_recommender:
context = (
self.working_memory.get()[-1].content if self.working_memory.get() else ""
) # thoughts from _think stage in 'react' mode

View file

@ -1,3 +1,4 @@
import ast
import inspect
from metagpt.utils.parse_docstring import GoogleDocstringParser, remove_spaces
@ -5,9 +6,10 @@ from metagpt.utils.parse_docstring import GoogleDocstringParser, remove_spaces
PARSER = GoogleDocstringParser
def convert_code_to_tool_schema(obj, include: list[str] = None):
def convert_code_to_tool_schema(obj, include: list[str] = None) -> dict:
"""Converts an object (function or class) to a tool schema by inspecting the object"""
docstring = inspect.getdoc(obj)
assert docstring, "no docstring found for the objects, skip registering"
# assert docstring, "no docstring found for the objects, skip registering"
if inspect.isclass(obj):
schema = {"type": "class", "description": remove_spaces(docstring), "methods": {}}
@ -27,6 +29,16 @@ def convert_code_to_tool_schema(obj, include: list[str] = None):
return schema
def convert_code_to_tool_schema_ast(code: str) -> list[dict]:
"""Converts a code string to a list of tool schemas by parsing the code with AST"""
visitor = CodeVisitor(code)
parsed_code = ast.parse(code)
visitor.visit(parsed_code)
return visitor.get_tool_schemas()
def function_docstring_to_schema(fn_obj, docstring) -> dict:
"""
Converts a function's docstring into a schema dictionary.
@ -62,3 +74,67 @@ def get_class_method_docstring(cls, method_name):
if method.__doc__:
return method.__doc__
return None # No docstring found in the class hierarchy
class CodeVisitor(ast.NodeVisitor):
"""Visit and convert the AST nodes within a code file to tool schemas"""
def __init__(self, source_code: str):
self.tool_schemas = {} # {tool_name: tool_schema}
self.source_code = source_code
def visit_ClassDef(self, node):
class_schemas = {"type": "class", "description": remove_spaces(ast.get_docstring(node)), "methods": {}}
for body_node in node.body:
if isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef)) and (
not body_node.name.startswith("_") or body_node.name == "__init__"
):
func_schemas = self._get_function_schemas(body_node)
class_schemas["methods"].update({body_node.name: func_schemas})
class_schemas["code"] = ast.get_source_segment(self.source_code, node)
self.tool_schemas[node.name] = class_schemas
def visit_FunctionDef(self, node):
self._visit_function(node)
def visit_AsyncFunctionDef(self, node):
self._visit_function(node)
def _visit_function(self, node):
if node.name.startswith("_"):
return
function_schemas = self._get_function_schemas(node)
function_schemas["code"] = ast.get_source_segment(self.source_code, node)
self.tool_schemas[node.name] = function_schemas
def _get_function_schemas(self, node):
docstring = remove_spaces(ast.get_docstring(node))
overall_desc, param_desc = PARSER.parse(docstring)
return {
"type": "async_function" if isinstance(node, ast.AsyncFunctionDef) else "function",
"description": overall_desc,
"signature": self._get_function_signature(node),
"parameters": param_desc,
}
def _get_function_signature(self, node):
args = []
defaults = dict(zip([arg.arg for arg in node.args.args][-len(node.args.defaults) :], node.args.defaults))
for arg in node.args.args:
arg_str = arg.arg
if arg.annotation:
annotation = ast.unparse(arg.annotation)
arg_str += f": {annotation}"
if arg.arg in defaults:
default_value = ast.unparse(defaults[arg.arg])
arg_str += f" = {default_value}"
args.append(arg_str)
return_annotation = ""
if node.returns:
return_annotation = f" -> {ast.unparse(node.returns)}"
return f"({', '.join(args)}){return_annotation}"
def get_tool_schemas(self):
return self.tool_schemas

View file

@ -3,7 +3,6 @@ from __future__ import annotations
import json
from typing import Any
import jieba
import numpy as np
from pydantic import BaseModel, field_validator
from rank_bm25 import BM25Okapi
@ -182,7 +181,7 @@ class BM25ToolRecommender(ToolRecommender):
self.bm25 = BM25Okapi(tokenized_corpus)
def _tokenize(self, text):
return jieba.lcut(text) # FIXME: needs more sophisticated tokenization
return text.split() # FIXME: needs more sophisticated tokenization
async def recall_tools(self, context: str = "", plan: Plan = None, topk: int = 20) -> list[Tool]:
query = plan.current_task.instruction if plan else context
@ -193,7 +192,7 @@ class BM25ToolRecommender(ToolRecommender):
recalled_tools = [list(self.tools.values())[index] for index in top_indexes]
logger.info(
f"Recalled tools: \n{[tool.name for tool in recalled_tools]}; Scores: {[doc_scores[index] for index in top_indexes]}"
f"Recalled tools: \n{[tool.name for tool in recalled_tools]}; Scores: {[np.round(doc_scores[index], 4) for index in top_indexes]}"
)
return recalled_tools

View file

@ -10,14 +10,17 @@ from __future__ import annotations
import inspect
import os
from collections import defaultdict
from typing import Union
from pathlib import Path
import yaml
from pydantic import BaseModel
from metagpt.const import TOOL_SCHEMA_PATH
from metagpt.logs import logger
from metagpt.tools.tool_convert import convert_code_to_tool_schema
from metagpt.tools.tool_convert import (
convert_code_to_tool_schema,
convert_code_to_tool_schema_ast,
)
from metagpt.tools.tool_data_type import Tool, ToolSchema
@ -27,21 +30,23 @@ class ToolRegistry(BaseModel):
def register_tool(
self,
tool_name,
tool_path,
schema_path="",
tool_code="",
tags=None,
tool_source_object=None,
include_functions=None,
verbose=False,
tool_name: str,
tool_path: str,
schemas: dict = None,
schema_path: str = "",
tool_code: str = "",
tags: list[str] = None,
tool_source_object=None, # can be any classes or functions
include_functions: list[str] = None,
verbose: bool = False,
):
if self.has_tool(tool_name):
return
schema_path = schema_path or TOOL_SCHEMA_PATH / f"{tool_name}.yml"
schemas = make_schema(tool_source_object, include_functions, schema_path)
if not schemas:
schemas = make_schema(tool_source_object, include_functions, schema_path)
if not schemas:
return
@ -117,9 +122,6 @@ def make_schema(tool_source_object, include, path):
schema = convert_code_to_tool_schema(tool_source_object, include=include)
with open(path, "w", encoding="utf-8") as f:
yaml.dump(schema, f, sort_keys=False)
# import json
# with open(str(path).replace("yml", "json"), "w", encoding="utf-8") as f:
# json.dump(schema, f, ensure_ascii=False, indent=4)
except Exception as e:
schema = {}
logger.error(f"Fail to make schema: {e}")
@ -127,15 +129,49 @@ def make_schema(tool_source_object, include, path):
return schema
def validate_tool_names(tools: Union[list[str], str]) -> str:
def validate_tool_names(tools: list[str]) -> dict[str, Tool]:
assert isinstance(tools, list), "tools must be a list of str"
valid_tools = {}
for key in tools:
# one can define either tool names or tool type names, take union to get the whole set
if TOOL_REGISTRY.has_tool(key):
# one can define either tool names OR tool tags OR tool path, take union to get the whole set
# if tool paths are provided, they will be registered on the fly
if os.path.isdir(key) or os.path.isfile(key):
valid_tools.update(register_tools_from_path(key))
elif TOOL_REGISTRY.has_tool(key):
valid_tools.update({key: TOOL_REGISTRY.get_tool(key)})
elif TOOL_REGISTRY.has_tool_tag(key):
valid_tools.update(TOOL_REGISTRY.get_tools_by_tag(key))
else:
logger.warning(f"invalid tool name or tool type name: {key}, skipped")
return valid_tools
def register_tools_from_file(file_path) -> dict[str, Tool]:
file_name = Path(file_path).name
if not file_name.endswith(".py") or file_name == "setup.py" or file_name.startswith("test"):
return {}
registered_tools = {}
code = Path(file_path).read_text(encoding="utf-8")
tool_schemas = convert_code_to_tool_schema_ast(code)
for name, schemas in tool_schemas.items():
tool_code = schemas.pop("code", "")
TOOL_REGISTRY.register_tool(
tool_name=name,
tool_path=file_path,
schemas=schemas,
tool_code=tool_code,
)
registered_tools.update({name: TOOL_REGISTRY.get_tool(name)})
return registered_tools
def register_tools_from_path(path) -> dict[str, Tool]:
tools_registered = {}
if os.path.isfile(path):
tools_registered.update(register_tools_from_file(path))
elif os.path.isdir(path):
for root, _, files in os.walk(path):
for file in files:
file_path = os.path.join(root, file)
tools_registered.update(register_tools_from_file(file_path))
return tools_registered

View file

@ -3,7 +3,7 @@ from typing import Tuple
def remove_spaces(text):
return re.sub(r"\s+", " ", text).strip()
return re.sub(r"\s+", " ", text).strip() if text else ""
class DocstringParser:

View file

@ -69,5 +69,4 @@ imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py
qianfan==0.3.2
dashscope==1.14.1
rank-bm25==0.2.2 # for tool recommendation
jieba==0.42.1 # for tool recommendation
gymnasium==0.29.1
gymnasium==0.29.1

View file

@ -2,7 +2,10 @@ from typing import Literal, Union
import pandas as pd
from metagpt.tools.tool_convert import convert_code_to_tool_schema
from metagpt.tools.tool_convert import (
convert_code_to_tool_schema,
convert_code_to_tool_schema_ast,
)
class DummyClass:
@ -128,3 +131,91 @@ def test_convert_code_to_tool_schema_function():
def test_convert_code_to_tool_schema_async_function():
schema = convert_code_to_tool_schema(dummy_async_fn)
assert schema.get("type") == "async_function"
TEST_CODE_FILE_TEXT = '''
import pandas as pd # imported obj should not be parsed
from some_module1 import some_imported_function, SomeImportedClass # imported obj should not be parsed
from ..some_module2 import some_imported_function2 # relative import should not result in an error
class MyClass:
"""This is a MyClass docstring."""
def __init__(self, arg1):
"""This is the constructor docstring."""
self.arg1 = arg1
def my_method(self, arg2: Union[list[str], str], arg3: pd.DataFrame, arg4: int = 1, arg5: Literal["a","b","c"] = "a") -> Tuple[int, str]:
"""
This is a method docstring.
Args:
arg2 (Union[list[str], str]): A union of a list of strings and a string.
...
Returns:
Tuple[int, str]: A tuple of an integer and a string.
"""
return self.arg4 + arg5
async def my_async_method(self, some_arg) -> str:
return "hi"
def _private_method(self): # private should not be parsed
return "private"
def my_function(arg1, arg2) -> dict:
"""This is a function docstring."""
return arg1 + arg2
def my_async_function(arg1, arg2) -> dict:
return arg1 + arg2
def _private_function(): # private should not be parsed
return "private"
'''
def test_convert_code_to_tool_schema_ast():
expected = {
"MyClass": {
"type": "class",
"description": "This is a MyClass docstring.",
"methods": {
"__init__": {
"type": "function",
"description": "This is the constructor docstring.",
"signature": "(self, arg1)",
"parameters": "",
},
"my_method": {
"type": "function",
"description": "This is a method docstring. ",
"signature": "(self, arg2: Union[list[str], str], arg3: pd.DataFrame, arg4: int = 1, arg5: Literal['a', 'b', 'c'] = 'a') -> Tuple[int, str]",
"parameters": "Args: arg2 (Union[list[str], str]): A union of a list of strings and a string. ... Returns: Tuple[int, str]: A tuple of an integer and a string.",
},
"my_async_method": {
"type": "async_function",
"description": "",
"signature": "(self, some_arg) -> str",
"parameters": "",
},
},
"code": 'class MyClass:\n """This is a MyClass docstring."""\n def __init__(self, arg1):\n """This is the constructor docstring."""\n self.arg1 = arg1\n\n def my_method(self, arg2: Union[list[str], str], arg3: pd.DataFrame, arg4: int = 1, arg5: Literal["a","b","c"] = "a") -> Tuple[int, str]:\n """\n This is a method docstring.\n \n Args:\n arg2 (Union[list[str], str]): A union of a list of strings and a string.\n ...\n \n Returns:\n Tuple[int, str]: A tuple of an integer and a string.\n """\n return self.arg4 + arg5\n \n async def my_async_method(self, some_arg) -> str:\n return "hi"\n \n def _private_method(self): # private should not be parsed\n return "private"',
},
"my_function": {
"type": "function",
"description": "This is a function docstring.",
"signature": "(arg1, arg2) -> dict",
"parameters": "",
"code": 'def my_function(arg1, arg2) -> dict:\n """This is a function docstring."""\n return arg1 + arg2',
},
"my_async_function": {
"type": "function",
"description": "",
"signature": "(arg1, arg2) -> dict",
"parameters": "",
"code": "def my_async_function(arg1, arg2) -> dict:\n return arg1 + arg2",
},
}
schemas = convert_code_to_tool_schema_ast(TEST_CODE_FILE_TEXT)
assert schemas == expected