mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-27 01:36:29 +02:00
Merge pull request #1116 from garylin2099/di_fixes
Register tools from a path
This commit is contained in:
commit
02e0eb0add
5 changed files with 212 additions and 52 deletions
|
|
@ -1,3 +1,4 @@
|
|||
import ast
|
||||
import inspect
|
||||
|
||||
from metagpt.utils.parse_docstring import GoogleDocstringParser, remove_spaces
|
||||
|
|
@ -5,7 +6,8 @@ from metagpt.utils.parse_docstring import GoogleDocstringParser, remove_spaces
|
|||
PARSER = GoogleDocstringParser
|
||||
|
||||
|
||||
def convert_code_to_tool_schema(obj, include: list[str] = None):
|
||||
def convert_code_to_tool_schema(obj, include: list[str] = None) -> dict:
|
||||
"""Converts an object (function or class) to a tool schema by inspecting the object"""
|
||||
docstring = inspect.getdoc(obj)
|
||||
# assert docstring, "no docstring found for the objects, skip registering"
|
||||
|
||||
|
|
@ -27,6 +29,23 @@ def convert_code_to_tool_schema(obj, include: list[str] = None):
|
|||
return schema
|
||||
|
||||
|
||||
def convert_code_to_tool_schema_ast(code: str) -> dict:
    """Convert a code string into tool schemas by parsing it with the AST module.

    Args:
        code: Python source text to parse.

    Returns:
        dict: Mapping of tool name to tool schema ({tool_name: tool_schema}),
        covering the top-level public classes and functions found in the code.
        (The previous ``-> list[dict]`` annotation was wrong: ``CodeVisitor``
        accumulates schemas in a dict keyed by name.)
    """

    def add_parent_references(node, parent=None):
        # Attach a .parent attribute to every node so the visitor can tell
        # whether a function is a method (parent is a ClassDef) or top-level.
        for child in ast.iter_child_nodes(node):
            child.parent = parent
            add_parent_references(child, parent=node)

    visitor = CodeVisitor(code)
    parsed_code = ast.parse(code)
    add_parent_references(parsed_code)
    visitor.visit(parsed_code)

    return visitor.get_tool_schemas()
|
||||
|
||||
|
||||
def function_docstring_to_schema(fn_obj, docstring) -> dict:
|
||||
"""
|
||||
Converts a function's docstring into a schema dictionary.
|
||||
|
|
@ -62,3 +81,67 @@ def get_class_method_docstring(cls, method_name):
|
|||
if method.__doc__:
|
||||
return method.__doc__
|
||||
return None # No docstring found in the class hierarchy
|
||||
|
||||
|
||||
class CodeVisitor(ast.NodeVisitor):
    """Walk the AST of a code file and collect tool schemas for its public definitions."""

    def __init__(self, source_code: str):
        # Accumulates results as {tool_name: tool_schema}.
        self.tool_schemas = {}
        self.source_code = source_code

    def visit_ClassDef(self, node):
        schema = {
            "type": "class",
            "description": remove_spaces(ast.get_docstring(node)),
            "methods": {},
        }
        for member in node.body:
            if not isinstance(member, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            # Keep public methods plus the constructor; skip other underscore-prefixed ones.
            if member.name.startswith("_") and member.name != "__init__":
                continue
            schema["methods"][member.name] = self._get_function_schemas(member)
        schema["code"] = ast.get_source_segment(self.source_code, node)
        self.tool_schemas[node.name] = schema

    def visit_FunctionDef(self, node):
        self._visit_function(node)

    def visit_AsyncFunctionDef(self, node):
        self._visit_function(node)

    def _visit_function(self, node):
        # Methods are handled via visit_ClassDef; private helpers are ignored.
        if isinstance(node.parent, ast.ClassDef) or node.name.startswith("_"):
            return
        schema = self._get_function_schemas(node)
        schema["code"] = ast.get_source_segment(self.source_code, node)
        self.tool_schemas[node.name] = schema

    def _get_function_schemas(self, node):
        # Parse the (normalized) docstring into an overall description plus
        # a parameter-section description.
        overall_desc, param_desc = PARSER.parse(remove_spaces(ast.get_docstring(node)))
        is_async = isinstance(node, ast.AsyncFunctionDef)
        return {
            "type": "async_function" if is_async else "function",
            "description": overall_desc,
            "signature": self._get_function_signature(node),
            "parameters": param_desc,
        }

    def _get_function_signature(self, node):
        # Defaults align with the *last* len(defaults) positional args.
        num_defaults = len(node.args.defaults)
        default_map = {}
        if num_defaults:
            trailing = node.args.args[-num_defaults:]
            default_map = {a.arg: d for a, d in zip(trailing, node.args.defaults)}

        rendered = []
        for arg in node.args.args:
            piece = arg.arg
            if arg.annotation:
                piece = f"{piece}: {ast.unparse(arg.annotation)}"
            if arg.arg in default_map:
                piece = f"{piece} = {ast.unparse(default_map[arg.arg])}"
            rendered.append(piece)

        returns = f" -> {ast.unparse(node.returns)}" if node.returns else ""
        return f"({', '.join(rendered)}){returns}"

    def get_tool_schemas(self):
        return self.tool_schemas
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
|||
import json
|
||||
from typing import Any
|
||||
|
||||
import jieba
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, field_validator
|
||||
from rank_bm25 import BM25Okapi
|
||||
|
|
@ -182,7 +181,7 @@ class BM25ToolRecommender(ToolRecommender):
|
|||
self.bm25 = BM25Okapi(tokenized_corpus)
|
||||
|
||||
def _tokenize(self, text):
|
||||
return jieba.lcut(text) # FIXME: needs more sophisticated tokenization
|
||||
return text.split() # FIXME: needs more sophisticated tokenization
|
||||
|
||||
async def recall_tools(self, context: str = "", plan: Plan = None, topk: int = 20) -> list[Tool]:
|
||||
query = plan.current_task.instruction if plan else context
|
||||
|
|
@ -193,7 +192,7 @@ class BM25ToolRecommender(ToolRecommender):
|
|||
recalled_tools = [list(self.tools.values())[index] for index in top_indexes]
|
||||
|
||||
logger.info(
|
||||
f"Recalled tools: \n{[tool.name for tool in recalled_tools]}; Scores: {[doc_scores[index] for index in top_indexes]}"
|
||||
f"Recalled tools: \n{[tool.name for tool in recalled_tools]}; Scores: {[np.round(doc_scores[index], 4) for index in top_indexes]}"
|
||||
)
|
||||
|
||||
return recalled_tools
|
||||
|
|
|
|||
|
|
@ -7,17 +7,20 @@
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import inspect
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.const import TOOL_SCHEMA_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.tool_convert import convert_code_to_tool_schema
|
||||
from metagpt.tools.tool_convert import (
|
||||
convert_code_to_tool_schema,
|
||||
convert_code_to_tool_schema_ast,
|
||||
)
|
||||
from metagpt.tools.tool_data_type import Tool, ToolSchema
|
||||
|
||||
|
||||
|
|
@ -27,21 +30,23 @@ class ToolRegistry(BaseModel):
|
|||
|
||||
def register_tool(
|
||||
self,
|
||||
tool_name,
|
||||
tool_path,
|
||||
schema_path="",
|
||||
tool_code="",
|
||||
tags=None,
|
||||
tool_source_object=None,
|
||||
include_functions=None,
|
||||
verbose=False,
|
||||
tool_name: str,
|
||||
tool_path: str,
|
||||
schemas: dict = None,
|
||||
schema_path: str = "",
|
||||
tool_code: str = "",
|
||||
tags: list[str] = None,
|
||||
tool_source_object=None, # can be any classes or functions
|
||||
include_functions: list[str] = None,
|
||||
verbose: bool = False,
|
||||
):
|
||||
if self.has_tool(tool_name):
|
||||
return
|
||||
|
||||
schema_path = schema_path or TOOL_SCHEMA_PATH / f"{tool_name}.yml"
|
||||
|
||||
schemas = make_schema(tool_source_object, include_functions, schema_path)
|
||||
if not schemas:
|
||||
schemas = make_schema(tool_source_object, include_functions, schema_path)
|
||||
|
||||
if not schemas:
|
||||
return
|
||||
|
|
@ -117,9 +122,6 @@ def make_schema(tool_source_object, include, path):
|
|||
schema = convert_code_to_tool_schema(tool_source_object, include=include)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(schema, f, sort_keys=False)
|
||||
# import json
|
||||
# with open(str(path).replace("yml", "json"), "w", encoding="utf-8") as f:
|
||||
# json.dump(schema, f, ensure_ascii=False, indent=4)
|
||||
except Exception as e:
|
||||
schema = {}
|
||||
logger.error(f"Fail to make schema: {e}")
|
||||
|
|
@ -144,46 +146,32 @@ def validate_tool_names(tools: list[str]) -> dict[str, Tool]:
|
|||
return valid_tools
|
||||
|
||||
|
||||
def load_module_from_file(filepath):
    """Dynamically import a Python file and return the resulting module object.

    The module is named after the file's basename without its extension and is
    executed immediately; it is not registered in sys.modules.
    """
    stem = Path(filepath).stem
    spec = importlib.util.spec_from_file_location(stem, filepath)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
def register_tools_from_file(file_path) -> dict[str, Tool]:
|
||||
file_name = Path(file_path).name
|
||||
if not file_name.endswith(".py") or file_name == "setup.py" or file_name.startswith("test"):
|
||||
return {}
|
||||
registered_tools = {}
|
||||
module = load_module_from_file(file_path)
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) or inspect.isfunction(obj):
|
||||
if obj.__module__ == module.__name__:
|
||||
# excluding imported classes and functions, register only those defined in the file
|
||||
if "metagpt" in file_path:
|
||||
# split to handle ../metagpt/metagpt/tools/... where only metagpt/tools/... is needed
|
||||
file_path = "metagpt" + file_path.split("metagpt")[-1]
|
||||
|
||||
TOOL_REGISTRY.register_tool(
|
||||
tool_name=name,
|
||||
tool_path=file_path,
|
||||
tool_code="", # inspect.getsource(obj) will resulted in TypeError, skip it for now
|
||||
tool_source_object=obj,
|
||||
)
|
||||
registered_tools.update({name: TOOL_REGISTRY.get_tool(name)})
|
||||
|
||||
code = Path(file_path).read_text(encoding="utf-8")
|
||||
tool_schemas = convert_code_to_tool_schema_ast(code)
|
||||
for name, schemas in tool_schemas.items():
|
||||
tool_code = schemas.pop("code", "")
|
||||
TOOL_REGISTRY.register_tool(
|
||||
tool_name=name,
|
||||
tool_path=file_path,
|
||||
schemas=schemas,
|
||||
tool_code=tool_code,
|
||||
)
|
||||
registered_tools.update({name: TOOL_REGISTRY.get_tool(name)})
|
||||
return registered_tools
|
||||
|
||||
|
||||
def register_tools_from_path(path) -> dict[str, Tool]:
|
||||
tools_registered = {}
|
||||
if os.path.isfile(path) and path.endswith(".py"):
|
||||
# Path is a Python file
|
||||
if os.path.isfile(path):
|
||||
tools_registered.update(register_tools_from_file(path))
|
||||
elif os.path.isdir(path):
|
||||
# Path is a directory
|
||||
for root, _, files in os.walk(path):
|
||||
for file in files:
|
||||
if file.endswith(".py"):
|
||||
file_path = os.path.join(root, file)
|
||||
tools_registered.update(register_tools_from_file(file_path))
|
||||
file_path = os.path.join(root, file)
|
||||
tools_registered.update(register_tools_from_file(file_path))
|
||||
return tools_registered
|
||||
|
|
|
|||
|
|
@ -71,5 +71,4 @@ Pillow
|
|||
imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py
|
||||
qianfan==0.3.2
|
||||
dashscope==1.14.1
|
||||
rank-bm25==0.2.2 # for tool recommendation
|
||||
jieba==0.42.1 # for tool recommendation
|
||||
rank-bm25==0.2.2 # for tool recommendation
|
||||
|
|
@ -2,7 +2,10 @@ from typing import Literal, Union
|
|||
|
||||
import pandas as pd
|
||||
|
||||
from metagpt.tools.tool_convert import convert_code_to_tool_schema
|
||||
from metagpt.tools.tool_convert import (
|
||||
convert_code_to_tool_schema,
|
||||
convert_code_to_tool_schema_ast,
|
||||
)
|
||||
|
||||
|
||||
class DummyClass:
|
||||
|
|
@ -128,3 +131,91 @@ def test_convert_code_to_tool_schema_function():
|
|||
def test_convert_code_to_tool_schema_async_function():
|
||||
schema = convert_code_to_tool_schema(dummy_async_fn)
|
||||
assert schema.get("type") == "async_function"
|
||||
|
||||
|
||||
TEST_CODE_FILE_TEXT = '''
|
||||
import pandas as pd # imported obj should not be parsed
|
||||
from some_module1 import some_imported_function, SomeImportedClass # imported obj should not be parsed
|
||||
from ..some_module2 import some_imported_function2 # relative import should not result in an error
|
||||
|
||||
class MyClass:
|
||||
"""This is a MyClass docstring."""
|
||||
def __init__(self, arg1):
|
||||
"""This is the constructor docstring."""
|
||||
self.arg1 = arg1
|
||||
|
||||
def my_method(self, arg2: Union[list[str], str], arg3: pd.DataFrame, arg4: int = 1, arg5: Literal["a","b","c"] = "a") -> Tuple[int, str]:
|
||||
"""
|
||||
This is a method docstring.
|
||||
|
||||
Args:
|
||||
arg2 (Union[list[str], str]): A union of a list of strings and a string.
|
||||
...
|
||||
|
||||
Returns:
|
||||
Tuple[int, str]: A tuple of an integer and a string.
|
||||
"""
|
||||
return self.arg4 + arg5
|
||||
|
||||
async def my_async_method(self, some_arg) -> str:
|
||||
return "hi"
|
||||
|
||||
def _private_method(self): # private should not be parsed
|
||||
return "private"
|
||||
|
||||
def my_function(arg1, arg2) -> dict:
|
||||
"""This is a function docstring."""
|
||||
return arg1 + arg2
|
||||
|
||||
def my_async_function(arg1, arg2) -> dict:
|
||||
return arg1 + arg2
|
||||
|
||||
def _private_function(): # private should not be parsed
|
||||
return "private"
|
||||
'''
|
||||
|
||||
|
||||
def test_convert_code_to_tool_schema_ast():
|
||||
expected = {
|
||||
"MyClass": {
|
||||
"type": "class",
|
||||
"description": "This is a MyClass docstring.",
|
||||
"methods": {
|
||||
"__init__": {
|
||||
"type": "function",
|
||||
"description": "This is the constructor docstring.",
|
||||
"signature": "(self, arg1)",
|
||||
"parameters": "",
|
||||
},
|
||||
"my_method": {
|
||||
"type": "function",
|
||||
"description": "This is a method docstring. ",
|
||||
"signature": "(self, arg2: Union[list[str], str], arg3: pd.DataFrame, arg4: int = 1, arg5: Literal['a', 'b', 'c'] = 'a') -> Tuple[int, str]",
|
||||
"parameters": "Args: arg2 (Union[list[str], str]): A union of a list of strings and a string. ... Returns: Tuple[int, str]: A tuple of an integer and a string.",
|
||||
},
|
||||
"my_async_method": {
|
||||
"type": "async_function",
|
||||
"description": "",
|
||||
"signature": "(self, some_arg) -> str",
|
||||
"parameters": "",
|
||||
},
|
||||
},
|
||||
"code": 'class MyClass:\n """This is a MyClass docstring."""\n def __init__(self, arg1):\n """This is the constructor docstring."""\n self.arg1 = arg1\n\n def my_method(self, arg2: Union[list[str], str], arg3: pd.DataFrame, arg4: int = 1, arg5: Literal["a","b","c"] = "a") -> Tuple[int, str]:\n """\n This is a method docstring.\n \n Args:\n arg2 (Union[list[str], str]): A union of a list of strings and a string.\n ...\n \n Returns:\n Tuple[int, str]: A tuple of an integer and a string.\n """\n return self.arg4 + arg5\n \n async def my_async_method(self, some_arg) -> str:\n return "hi"\n \n def _private_method(self): # private should not be parsed\n return "private"',
|
||||
},
|
||||
"my_function": {
|
||||
"type": "function",
|
||||
"description": "This is a function docstring.",
|
||||
"signature": "(arg1, arg2) -> dict",
|
||||
"parameters": "",
|
||||
"code": 'def my_function(arg1, arg2) -> dict:\n """This is a function docstring."""\n return arg1 + arg2',
|
||||
},
|
||||
"my_async_function": {
|
||||
"type": "function",
|
||||
"description": "",
|
||||
"signature": "(arg1, arg2) -> dict",
|
||||
"parameters": "",
|
||||
"code": "def my_async_function(arg1, arg2) -> dict:\n return arg1 + arg2",
|
||||
},
|
||||
}
|
||||
schemas = convert_code_to_tool_schema_ast(TEST_CODE_FILE_TEXT)
|
||||
assert schemas == expected
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue