From 67464c37f11d854da43db99ad82a95bb81018962 Mon Sep 17 00:00:00 2001 From: yzlin Date: Mon, 25 Mar 2024 16:09:31 +0800 Subject: [PATCH] use ast to parse code texts and register a whole repo as tools --- metagpt/tools/tool_convert.py | 80 +++++++++++++++++++++++++++++++++- metagpt/tools/tool_registry.py | 76 +++++++++++++------------------- 2 files changed, 110 insertions(+), 46 deletions(-) diff --git a/metagpt/tools/tool_convert.py b/metagpt/tools/tool_convert.py index d366bef41..529f5ec14 100644 --- a/metagpt/tools/tool_convert.py +++ b/metagpt/tools/tool_convert.py @@ -1,3 +1,4 @@ +import ast import inspect from metagpt.utils.parse_docstring import GoogleDocstringParser, remove_spaces @@ -5,7 +6,8 @@ from metagpt.utils.parse_docstring import GoogleDocstringParser, remove_spaces PARSER = GoogleDocstringParser -def convert_code_to_tool_schema(obj, include: list[str] = None): +def convert_code_to_tool_schema(obj, include: list[str] = None) -> dict: + """Converts an object (function or class) to a tool schema by inspecting the object""" docstring = inspect.getdoc(obj) # assert docstring, "no docstring found for the objects, skip registering" @@ -27,6 +29,23 @@ def convert_code_to_tool_schema(obj, include: list[str] = None): return schema +def convert_code_to_tool_schema_ast(code: str) -> list[dict]: + """Converts a code string to a list of tool schemas by parsing the code with AST""" + + # Modify the AST nodes to include parent references, enabling to attach methods to their class + def add_parent_references(node, parent=None): + for child in ast.iter_child_nodes(node): + child.parent = parent + add_parent_references(child, parent=node) + + visitor = CodeVisitor() + parsed_code = ast.parse(code) + add_parent_references(parsed_code) + visitor.visit(parsed_code) + + return visitor.get_tool_schemas() + + def function_docstring_to_schema(fn_obj, docstring) -> dict: """ Converts a function's docstring into a schema dictionary. @@ -62,3 +81,62 @@ def get_class_method_docstring(cls, method_name): if method.__doc__: return method.__doc__ return None # No docstring found in the class hierarchy + + +class CodeVisitor(ast.NodeVisitor): + """Visit and convert the AST nodes within a code file to tool schemas""" + + def __init__(self): + self.tool_schemas = {} # {tool_name: tool_schema} + + def visit_ClassDef(self, node): + class_schemas = {"type": "class", "description": remove_spaces(ast.get_docstring(node)), "methods": {}} + for body_node in node.body: + if isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef)) and ( + not body_node.name.startswith("_") or body_node.name == "__init__" + ): + func_schemas = self._get_function_schemas(body_node) + class_schemas["methods"].update({body_node.name: func_schemas}) + self.tool_schemas[node.name] = class_schemas + + def visit_FunctionDef(self, node): + if isinstance(node.parent, ast.ClassDef) or node.name.startswith("_"): + return + self.tool_schemas[node.name] = self._get_function_schemas(node) + + def visit_AsyncFunctionDef(self, node): + if isinstance(node.parent, ast.ClassDef) or node.name.startswith("_"): + return + self.tool_schemas[node.name] = self._get_function_schemas(node) + + def _get_function_schemas(self, node): + docstring = remove_spaces(ast.get_docstring(node)) + overall_desc, param_desc = PARSER.parse(docstring) + return { + "type": "async_function" if isinstance(node, ast.AsyncFunctionDef) else "function", + "description": overall_desc, + "signature": self._get_function_signature(node), + "parameters": param_desc, + } + + def _get_function_signature(self, node): + args = [] + defaults = dict(zip([arg.arg for arg in node.args.args][-len(node.args.defaults) :], node.args.defaults)) + for arg in node.args.args: + arg_str = arg.arg + if arg.annotation: + annotation = ast.unparse(arg.annotation) + arg_str += f": {annotation}" + if arg.arg in defaults: + default_value = ast.unparse(defaults[arg.arg]) + arg_str += f" = {default_value}" + args.append(arg_str) + + return_annotation = "" + if node.returns: + return_annotation = f" -> {ast.unparse(node.returns)}" + + return f"({' ,'.join(args)}){return_annotation}" + + def get_tool_schemas(self): + return self.tool_schemas diff --git a/metagpt/tools/tool_registry.py b/metagpt/tools/tool_registry.py index e3d270b79..2fc44a2e8 100644 --- a/metagpt/tools/tool_registry.py +++ b/metagpt/tools/tool_registry.py @@ -7,17 +7,20 @@ """ from __future__ import annotations -import importlib.util import inspect import os from collections import defaultdict +from pathlib import Path import yaml from pydantic import BaseModel from metagpt.const import TOOL_SCHEMA_PATH from metagpt.logs import logger -from metagpt.tools.tool_convert import convert_code_to_tool_schema +from metagpt.tools.tool_convert import ( + convert_code_to_tool_schema, + convert_code_to_tool_schema_ast, +) from metagpt.tools.tool_data_type import Tool, ToolSchema @@ -27,21 +30,23 @@ class ToolRegistry(BaseModel): def register_tool( self, - tool_name, - tool_path, - schema_path="", - tool_code="", - tags=None, - tool_source_object=None, - include_functions=None, - verbose=False, + tool_name: str, + tool_path: str, + schemas: dict = None, + schema_path: str = "", + tool_code: str = "", + tags: list[str] = None, + tool_source_object=None, # can be any classes or functions + include_functions: list[str] = None, + verbose: bool = False, ): if self.has_tool(tool_name): return schema_path = schema_path or TOOL_SCHEMA_PATH / f"{tool_name}.yml" - schemas = make_schema(tool_source_object, include_functions, schema_path) + if not schemas: + schemas = make_schema(tool_source_object, include_functions, schema_path) if not schemas: return @@ -117,9 +122,6 @@ def make_schema(tool_source_object, include, path): schema = convert_code_to_tool_schema(tool_source_object, include=include) with open(path, "w", encoding="utf-8") as f: yaml.dump(schema, f, sort_keys=False) - # import json - # with open(str(path).replace("yml", "json"), "w", encoding="utf-8") as f: - # json.dump(schema, f, ensure_ascii=False, indent=4) except Exception as e: schema = {} logger.error(f"Fail to make schema: {e}") @@ -144,46 +146,30 @@ def validate_tool_names(tools: list[str]) -> dict[str, Tool]: return valid_tools -def load_module_from_file(filepath): - module_name = os.path.splitext(os.path.basename(filepath))[0] - spec = importlib.util.spec_from_file_location(module_name, filepath) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module - - def register_tools_from_file(file_path) -> dict[str, Tool]: + file_name = Path(file_path).name + if not file_name.endswith(".py") or file_name == "setup.py" or file_name.startswith("test"): + return {} registered_tools = {} - module = load_module_from_file(file_path) - for name, obj in inspect.getmembers(module): - if inspect.isclass(obj) or inspect.isfunction(obj): - if obj.__module__ == module.__name__: - # excluding imported classes and functions, register only those defined in the file - if "metagpt" in file_path: - # split to handle ../metagpt/metagpt/tools/... where only metapgt/tools/... is needed - file_path = "metagpt" + file_path.split("metagpt")[-1] - - TOOL_REGISTRY.register_tool( - tool_name=name, - tool_path=file_path, - tool_code="", # inspect.getsource(obj) will resulted in TypeError, skip it for now - tool_source_object=obj, - ) - registered_tools.update({name: TOOL_REGISTRY.get_tool(name)}) - + code = Path(file_path).read_text(encoding="utf-8") + tool_schemas = convert_code_to_tool_schema_ast(code) + for name, schemas in tool_schemas.items(): + TOOL_REGISTRY.register_tool( + tool_name=name, + tool_path=file_path, + schemas=schemas, + ) + registered_tools.update({name: TOOL_REGISTRY.get_tool(name)}) return registered_tools def register_tools_from_path(path) -> dict[str, Tool]: tools_registered = {} - if os.path.isfile(path) and path.endswith(".py"): - # Path is a Python file + if os.path.isfile(path): tools_registered.update(register_tools_from_file(path)) elif os.path.isdir(path): - # Path is a directory for root, _, files in os.walk(path): for file in files: - if file.endswith(".py"): - file_path = os.path.join(root, file) - tools_registered.update(register_tools_from_file(file_path)) + file_path = os.path.join(root, file) + tools_registered.update(register_tools_from_file(file_path)) return tools_registered