From da98495657bd2d4244b41330d999a38866137f36 Mon Sep 17 00:00:00 2001 From: yzlin Date: Mon, 25 Mar 2024 17:35:12 +0800 Subject: [PATCH] add tool code for ast parsing --- metagpt/tools/tool_convert.py | 17 +++++++++++------ metagpt/tools/tool_registry.py | 2 ++ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/metagpt/tools/tool_convert.py b/metagpt/tools/tool_convert.py index 529f5ec14..e6894762a 100644 --- a/metagpt/tools/tool_convert.py +++ b/metagpt/tools/tool_convert.py @@ -38,7 +38,7 @@ def convert_code_to_tool_schema_ast(code: str) -> list[dict]: child.parent = parent add_parent_references(child, parent=node) - visitor = CodeVisitor() + visitor = CodeVisitor(code) parsed_code = ast.parse(code) add_parent_references(parsed_code) visitor.visit(parsed_code) @@ -86,8 +86,9 @@ def get_class_method_docstring(cls, method_name): class CodeVisitor(ast.NodeVisitor): """Visit and convert the AST nodes within a code file to tool schemas""" - def __init__(self): + def __init__(self, source_code: str): self.tool_schemas = {} # {tool_name: tool_schema} + self.source_code = source_code def visit_ClassDef(self, node): class_schemas = {"type": "class", "description": remove_spaces(ast.get_docstring(node)), "methods": {}} @@ -97,17 +98,21 @@ class CodeVisitor(ast.NodeVisitor): ): func_schemas = self._get_function_schemas(body_node) class_schemas["methods"].update({body_node.name: func_schemas}) + class_schemas["code"] = ast.get_source_segment(self.source_code, node) self.tool_schemas[node.name] = class_schemas def visit_FunctionDef(self, node): - if isinstance(node.parent, ast.ClassDef) or node.name.startswith("_"): - return - self.tool_schemas[node.name] = self._get_function_schemas(node) + self._visit_function(node) def visit_AsyncFunctionDef(self, node): + self._visit_function(node) + + def _visit_function(self, node): if isinstance(node.parent, ast.ClassDef) or node.name.startswith("_"): return - self.tool_schemas[node.name] = self._get_function_schemas(node) + function_schemas = self._get_function_schemas(node) + function_schemas["code"] = ast.get_source_segment(self.source_code, node) + self.tool_schemas[node.name] = function_schemas def _get_function_schemas(self, node): docstring = remove_spaces(ast.get_docstring(node)) diff --git a/metagpt/tools/tool_registry.py b/metagpt/tools/tool_registry.py index 2fc44a2e8..50875e235 100644 --- a/metagpt/tools/tool_registry.py +++ b/metagpt/tools/tool_registry.py @@ -154,10 +154,12 @@ def register_tools_from_file(file_path) -> dict[str, Tool]: code = Path(file_path).read_text(encoding="utf-8") tool_schemas = convert_code_to_tool_schema_ast(code) for name, schemas in tool_schemas.items(): + tool_code = schemas.pop("code", "") TOOL_REGISTRY.register_tool( tool_name=name, tool_path=file_path, schemas=schemas, + tool_code=tool_code, ) registered_tools.update({name: TOOL_REGISTRY.get_tool(name)}) return registered_tools