diff --git a/examples/di/use_github_repo.py b/examples/di/use_github_repo.py new file mode 100644 index 000000000..ad541d2d9 --- /dev/null +++ b/examples/di/use_github_repo.py @@ -0,0 +1,18 @@ +import asyncio + +from metagpt.roles.di.data_interpreter import DataInterpreter + +USE_GOT_REPO_REQ = """ +This is a link to the GOT github repo: https://github.com/spcl/graph-of-thoughts.git. +Clone it, read the README to understand the usage, install it, and finally run the quick start example. +**Note the config for LLM is at `config/config_got.json`, use this path directly.** Don't write all codes in one response, each time, just write code for one step. +""" + + +async def main(): + di = DataInterpreter(tools=["Terminal"]) + await di.run(USE_GOT_REPO_REQ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/metagpt/actions/di/execute_nb_code.py b/metagpt/actions/di/execute_nb_code.py index 0cf16b70f..aab204499 100644 --- a/metagpt/actions/di/execute_nb_code.py +++ b/metagpt/actions/di/execute_nb_code.py @@ -106,7 +106,7 @@ class ExecuteNbCode(Action): else: cell["outputs"].append(new_output(output_type="stream", name="stdout", text=str(output))) - def parse_outputs(self, outputs: list[str], keep_len: int = 2000) -> Tuple[bool, str]: + def parse_outputs(self, outputs: list[str], keep_len: int = 5000) -> Tuple[bool, str]: """Parses the outputs received from notebook execution.""" assert isinstance(outputs, list) parsed_output, is_success = [], True diff --git a/metagpt/logs.py b/metagpt/logs.py index 90bac21aa..e134afca3 100644 --- a/metagpt/logs.py +++ b/metagpt/logs.py @@ -6,6 +6,8 @@ @File : logs.py """ +from __future__ import annotations + import sys from datetime import datetime from functools import partial @@ -34,9 +36,22 @@ def log_llm_stream(msg): _llm_stream_log(msg) +def log_tool_output(output: dict, tool_name: str = ""): + """interface for logging tool output, can be set to log tool output in different ways to different places with 
set_tool_output_logfunc""" + _tool_output_log(output) + + def set_llm_stream_logfunc(func): global _llm_stream_log _llm_stream_log = func +def set_tool_output_logfunc(func): + global _tool_output_log + _tool_output_log = func + + _llm_stream_log = partial(print, end="") + + +_tool_output_log = partial(print, end="") diff --git a/metagpt/prompts/di/write_analysis_code.py b/metagpt/prompts/di/write_analysis_code.py index e5663d498..d2b4f1299 100644 --- a/metagpt/prompts/di/write_analysis_code.py +++ b/metagpt/prompts/di/write_analysis_code.py @@ -1,4 +1,8 @@ -INTERPRETER_SYSTEM_MSG = """As a data scientist, you need to help user to achieve their goal step by step in a continuous Jupyter notebook. Since it is a notebook environment, don't use asyncio.run. Instead, use await if you need to call an async function.""" +INTERPRETER_SYSTEM_MSG = """ +As a data scientist, you need to help user to achieve their goal step by step in a continuous Jupyter notebook. +Since it is a notebook environment, don't use asyncio.run. Instead, use await if you need to call an async function. +If you want to use shell command such as git clone, pip install packages, navigate folders, read file, etc., use Terminal tool if available before trying ! in notebook block. 
+""" STRUCTUAL_PROMPT = """ # User Requirement diff --git a/metagpt/roles/di/data_interpreter.py b/metagpt/roles/di/data_interpreter.py index 35c6d1297..c78bc7e1a 100644 --- a/metagpt/roles/di/data_interpreter.py +++ b/metagpt/roles/di/data_interpreter.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import Literal, Union +from typing import Literal from pydantic import Field, model_validator @@ -39,7 +39,7 @@ class DataInterpreter(Role): use_plan: bool = True use_reflection: bool = False execute_code: ExecuteNbCode = Field(default_factory=ExecuteNbCode, exclude=True) - tools: Union[str, list[str]] = [] # Use special symbol [""] to indicate use of all registered tools + tools: list[str] = [] # Use special symbol [""] to indicate use of all registered tools tool_recommender: ToolRecommender = None react_mode: Literal["plan_and_act", "react"] = "plan_and_act" max_react_loop: int = 10 # used for react mode @@ -50,7 +50,7 @@ class DataInterpreter(Role): self.use_plan = ( self.react_mode == "plan_and_act" ) # create a flag for convenience, overwrite any passed-in value - if self.tools: + if self.tools and not self.tool_recommender: self.tool_recommender = BM25ToolRecommender(tools=self.tools) self.set_actions([WriteAnalysisCode]) self._set_state(0) @@ -104,7 +104,7 @@ class DataInterpreter(Role): plan_status = self.planner.get_plan_status() if self.use_plan else "" # tool info - if self.tools: + if self.tool_recommender: context = ( self.working_memory.get()[-1].content if self.working_memory.get() else "" ) # thoughts from _think stage in 'react' mode diff --git a/metagpt/tools/libs/__init__.py b/metagpt/tools/libs/__init__.py index eb5ffbc5c..cd70d9811 100644 --- a/metagpt/tools/libs/__init__.py +++ b/metagpt/tools/libs/__init__.py @@ -11,6 +11,7 @@ from metagpt.tools.libs import ( gpt_v_generator, web_scraping, email_login, + terminal, ) from metagpt.tools.libs.software_development import ( write_prd, @@ -36,4 +37,5 @@ _ = ( run_qa_test, 
fix_bug, git_archive, + terminal, ) # Avoid pre-commit error
diff --git a/metagpt/tools/libs/terminal.py b/metagpt/tools/libs/terminal.py
new file mode 100644
index 000000000..2b1657bdd
--- /dev/null
+++ b/metagpt/tools/libs/terminal.py
@@ -0,0 +1,71 @@
+import subprocess
+
+from metagpt.logs import log_tool_output
+from metagpt.tools.tool_registry import register_tool
+
+
+@register_tool()
+class Terminal:
+    """A tool for running terminal commands. Don't initialize a new instance of this class if one already exists."""
+
+    def __init__(self):
+        self.shell_command = ["bash"]  # FIXME: should consider windows support later
+        self.command_terminator = "\n"
+        self.end_marker = "#END_MARKER#"
+
+        # Start a persistent shell process
+        self.process = subprocess.Popen(
+            self.shell_command,
+            stdin=subprocess.PIPE,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            text=True,
+            bufsize=1,  # Line buffered
+        )
+
+    def run_command(self, cmd: str) -> str:
+        """
+        Run a command in the terminal and return the output.
+        When the command is being executed, stream the output to the terminal.
+        Maintains state across commands, such as current directory.
+
+        Args:
+            cmd (str): The command to run in the terminal.
+
+        Returns:
+            str: The output of the terminal command (whatever was read so far if the shell exits).
+        """
+        cmd_output = []
+
+        # Send the command
+        self.process.stdin.write(cmd + self.command_terminator)
+        self.process.stdin.write(
+            f'echo "{self.end_marker}"' + self.command_terminator
+        )  # Unique marker to signal command end
+        self.process.stdin.flush()
+        log_tool_output(output={"cmd": cmd + self.command_terminator}, tool_name="Terminal")  # log the command
+
+        # Read the output until the unique marker is found
+        while True:
+            line = self.process.stdout.readline()
+            if not line:
+                # EOF: the shell has exited (e.g. `exit`); readline() would keep returning "" forever
+                break
+            stripped = line.rstrip()
+            if stripped.endswith(self.end_marker):
+                # Output without a trailing newline shares its line with the marker; keep that part
+                leftover = stripped[: -len(self.end_marker)]
+                if leftover.strip():
+                    log_tool_output(output={"output": leftover}, tool_name="Terminal")
+                    cmd_output.append(leftover)
+                break
+            log_tool_output(output={"output": line}, tool_name="Terminal")  # log stdout in real-time
+            cmd_output.append(line)
+
+        return "".join(cmd_output)
+
+    def close(self):
+        """Close the persistent shell process."""
+        self.process.stdin.close()
+        self.process.terminate()
+        self.process.wait()
diff --git a/metagpt/tools/tool_convert.py b/metagpt/tools/tool_convert.py
index 42c65b9e7..829269b1b 100644
--- a/metagpt/tools/tool_convert.py
+++ b/metagpt/tools/tool_convert.py
@@ -1,3 +1,4 @@
+import ast
 import inspect
 
 from metagpt.utils.parse_docstring import GoogleDocstringParser, remove_spaces
@@ -5,9 +6,10 @@ from metagpt.utils.parse_docstring import GoogleDocstringParser, remove_spaces
 PARSER = GoogleDocstringParser
 
 
-def convert_code_to_tool_schema(obj, include: list[str] = None):
+def convert_code_to_tool_schema(obj, include: list[str] = None) -> dict:
+    """Converts an object (function or class) to a tool schema by inspecting the object"""
     docstring = inspect.getdoc(obj)
-    assert docstring, "no docstring found for the objects, skip registering"
+    # assert docstring, "no docstring found for the objects, skip registering"
 
     if inspect.isclass(obj):
         schema = {"type": "class", "description": remove_spaces(docstring), "methods": {}}
@@ -27,6 +29,16 @@ def convert_code_to_tool_schema(obj, include: list[str] = None):
     return schema
 
 
+def convert_code_to_tool_schema_ast(code: str) -> dict[str, dict]:
+    """Converts a code string to tool schemas, keyed by tool name, by parsing the code
with AST"""
+
+    visitor = CodeVisitor(code)
+    parsed_code = ast.parse(code)
+    visitor.visit(parsed_code)
+
+    return visitor.get_tool_schemas()
+
+
 def function_docstring_to_schema(fn_obj, docstring) -> dict:
     """
     Converts a function's docstring into a schema dictionary.
@@ -62,3 +74,83 @@ def get_class_method_docstring(cls, method_name):
         if method.__doc__:
             return method.__doc__
     return None  # No docstring found in the class hierarchy
+
+
+class CodeVisitor(ast.NodeVisitor):
+    """Visit and convert the AST nodes within a code file to tool schemas"""
+
+    def __init__(self, source_code: str):
+        self.tool_schemas = {}  # {tool_name: tool_schema}
+        self.source_code = source_code
+
+    def visit_ClassDef(self, node):
+        if node.name.startswith("_"):  # private classes are not tools, same rule as for functions
+            return
+        class_schemas = {"type": "class", "description": remove_spaces(ast.get_docstring(node)), "methods": {}}
+        for body_node in node.body:
+            if isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef)) and (
+                not body_node.name.startswith("_") or body_node.name == "__init__"
+            ):
+                func_schemas = self._get_function_schemas(body_node)
+                class_schemas["methods"].update({body_node.name: func_schemas})
+        class_schemas["code"] = ast.get_source_segment(self.source_code, node)
+        self.tool_schemas[node.name] = class_schemas
+
+    def visit_FunctionDef(self, node):
+        self._visit_function(node)
+
+    def visit_AsyncFunctionDef(self, node):
+        self._visit_function(node)
+
+    def _visit_function(self, node):
+        if node.name.startswith("_"):
+            return
+        function_schemas = self._get_function_schemas(node)
+        function_schemas["code"] = ast.get_source_segment(self.source_code, node)
+        self.tool_schemas[node.name] = function_schemas
+
+    def _get_function_schemas(self, node):
+        docstring = remove_spaces(ast.get_docstring(node))
+        overall_desc, param_desc = PARSER.parse(docstring)
+        return {
+            "type": "async_function" if isinstance(node, ast.AsyncFunctionDef) else "function",
+            "description": overall_desc,
+            "signature": self._get_function_signature(node),
+            "parameters": param_desc,
+        }
+
+    def _get_function_signature(self, node):
+        """Rebuild the signature text, including positional-only, *args, keyword-only and **kwargs parameters"""
+
+        def render(arg, default=None):
+            text = arg.arg
+            if arg.annotation:
+                text += f": {ast.unparse(arg.annotation)}"
+            if default is not None:
+                text += f" = {ast.unparse(default)}"
+            return text
+
+        spec = node.args
+        positional = spec.posonlyargs + spec.args
+        # defaults belong to the LAST len(defaults) positional parameters; pad the front with None
+        pos_defaults = [None] * (len(positional) - len(spec.defaults)) + spec.defaults
+        args = [render(arg, default) for arg, default in zip(positional, pos_defaults)]
+        if spec.posonlyargs:
+            args.insert(len(spec.posonlyargs), "/")
+        if spec.vararg:
+            args.append("*" + render(spec.vararg))
+        elif spec.kwonlyargs:
+            args.append("*")  # bare * that introduces keyword-only parameters
+        # kw_defaults is aligned with kwonlyargs; None marks a keyword-only parameter without default
+        args.extend(render(arg, default) for arg, default in zip(spec.kwonlyargs, spec.kw_defaults))
+        if spec.kwarg:
+            args.append("**" + render(spec.kwarg))
+
+        return_annotation = ""
+        if node.returns:
+            return_annotation = f" -> {ast.unparse(node.returns)}"
+
+        return f"({', '.join(args)}){return_annotation}"
+
+    def get_tool_schemas(self):
+        return self.tool_schemas
diff --git a/metagpt/tools/tool_recommend.py b/metagpt/tools/tool_recommend.py
index c4e324a0b..01ff61834 100644
--- a/metagpt/tools/tool_recommend.py
+++ b/metagpt/tools/tool_recommend.py
@@ -3,7 +3,6 @@ from __future__ import annotations
 import json
 from typing import Any
 
-import jieba
 import numpy as np
 from pydantic import BaseModel, field_validator
 from rank_bm25 import BM25Okapi
@@ -182,7 +181,7 @@ class BM25ToolRecommender(ToolRecommender):
         self.bm25 = BM25Okapi(tokenized_corpus)
 
     def _tokenize(self, text):
-        return jieba.lcut(text)  # FIXME: needs more sophisticated tokenization
+        return text.split()  # FIXME: needs more sophisticated tokenization
 
     async def recall_tools(self, context: str = "", plan: Plan = None, topk: int = 20) -> list[Tool]:
         query = plan.current_task.instruction if plan else context
@@ -193,7 +192,7 @@ class BM25ToolRecommender(ToolRecommender):
         recalled_tools = [list(self.tools.values())[index] for index in top_indexes]
 
         logger.info(
-            f"Recalled tools: \n{[tool.name for tool in recalled_tools]}; Scores: {[doc_scores[index] for index in top_indexes]}"
+            f"Recalled tools: \n{[tool.name for tool in recalled_tools]}; Scores: {[np.round(doc_scores[index], 4) for index in top_indexes]}"
         )
 
         return recalled_tools
diff --git a/metagpt/tools/tool_registry.py b/metagpt/tools/tool_registry.py index
11269cb0f..50875e235 100644
--- a/metagpt/tools/tool_registry.py
+++ b/metagpt/tools/tool_registry.py
@@ -10,14 +10,17 @@ from __future__ import annotations
 import inspect
 import os
 from collections import defaultdict
-from typing import Union
+from pathlib import Path
 
 import yaml
 from pydantic import BaseModel
 
 from metagpt.const import TOOL_SCHEMA_PATH
 from metagpt.logs import logger
-from metagpt.tools.tool_convert import convert_code_to_tool_schema
+from metagpt.tools.tool_convert import (
+    convert_code_to_tool_schema,
+    convert_code_to_tool_schema_ast,
+)
 from metagpt.tools.tool_data_type import Tool, ToolSchema
 
 
@@ -27,21 +30,23 @@ class ToolRegistry(BaseModel):
 
     def register_tool(
         self,
-        tool_name,
-        tool_path,
-        schema_path="",
-        tool_code="",
-        tags=None,
-        tool_source_object=None,
-        include_functions=None,
-        verbose=False,
+        tool_name: str,
+        tool_path: str,
+        schema_path: str = "",
+        tool_code: str = "",
+        tags: list[str] = None,
+        tool_source_object=None,  # can be any classes or functions
+        include_functions: list[str] = None,
+        verbose: bool = False,
+        schemas: dict = None,  # appended last so pre-existing positional callers keep their argument order
     ):
         if self.has_tool(tool_name):
             return
 
         schema_path = schema_path or TOOL_SCHEMA_PATH / f"{tool_name}.yml"
 
-        schemas = make_schema(tool_source_object, include_functions, schema_path)
+        if not schemas:
+            schemas = make_schema(tool_source_object, include_functions, schema_path)
 
         if not schemas:
             return
@@ -117,9 +122,6 @@ def make_schema(tool_source_object, include, path):
         schema = convert_code_to_tool_schema(tool_source_object, include=include)
         with open(path, "w", encoding="utf-8") as f:
             yaml.dump(schema, f, sort_keys=False)
-        # import json
-        # with open(str(path).replace("yml", "json"), "w", encoding="utf-8") as f:
-        #     json.dump(schema, f, ensure_ascii=False, indent=4)
     except Exception as e:
         schema = {}
         logger.error(f"Fail to make schema: {e}")
@@ -127,15 +129,52 @@ def make_schema(tool_source_object, include, path):
     return schema
 
 
-def validate_tool_names(tools: Union[list[str], str]) -> str:
+def validate_tool_names(tools: list[str]) -> dict[str, Tool]:
     assert isinstance(tools, list), "tools must be a list of str"
     valid_tools = {}
     for key in tools:
-        # one can define either tool names or tool type names, take union to get the whole set
+        # one can define either tool names OR tool tags OR tool path, take union to get the whole set
+        # registered names and tags are looked up first so a same-named local file/dir cannot shadow them
+        # if tool paths are provided, they will be registered on the fly
         if TOOL_REGISTRY.has_tool(key):
             valid_tools.update({key: TOOL_REGISTRY.get_tool(key)})
         elif TOOL_REGISTRY.has_tool_tag(key):
             valid_tools.update(TOOL_REGISTRY.get_tools_by_tag(key))
+        elif os.path.isdir(key) or os.path.isfile(key):
+            valid_tools.update(register_tools_from_path(key))
         else:
             logger.warning(f"invalid tool name or tool type name: {key}, skipped")
     return valid_tools
+
+
+def register_tools_from_file(file_path) -> dict[str, Tool]:
+    """Register every public class/function found in a python file; setup.py and test* files are skipped"""
+    file_name = Path(file_path).name
+    if not file_name.endswith(".py") or file_name == "setup.py" or file_name.startswith("test"):
+        return {}
+    registered_tools = {}
+    code = Path(file_path).read_text(encoding="utf-8")
+    tool_schemas = convert_code_to_tool_schema_ast(code)
+    for name, schemas in tool_schemas.items():
+        tool_code = schemas.pop("code", "")
+        TOOL_REGISTRY.register_tool(
+            tool_name=name,
+            tool_path=file_path,
+            schemas=schemas,
+            tool_code=tool_code,
+        )
+        registered_tools.update({name: TOOL_REGISTRY.get_tool(name)})
+    return registered_tools
+
+
+def register_tools_from_path(path) -> dict[str, Tool]:
+    """Register tools from a single python file, or from every python file found under a directory"""
+    tools_registered = {}
+    if os.path.isfile(path):
+        tools_registered.update(register_tools_from_file(path))
+    elif os.path.isdir(path):
+        for root, _, files in os.walk(path):
+            for file in files:
+                file_path = os.path.join(root, file)
+                tools_registered.update(register_tools_from_file(file_path))
+    return tools_registered
diff --git a/metagpt/utils/parse_docstring.py b/metagpt/utils/parse_docstring.py index 63c0e6890..5df4d6671 100644 ---
a/metagpt/utils/parse_docstring.py +++ b/metagpt/utils/parse_docstring.py @@ -3,7 +3,7 @@ from typing import Tuple def remove_spaces(text): - return re.sub(r"\s+", " ", text).strip() + return re.sub(r"\s+", " ", text).strip() if text else "" class DocstringParser: diff --git a/requirements.txt b/requirements.txt index d150d61f3..83962b21b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -69,5 +69,4 @@ imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py qianfan==0.3.2 dashscope==1.14.1 rank-bm25==0.2.2 # for tool recommendation -jieba==0.42.1 # for tool recommendation -gymnasium==0.29.1 \ No newline at end of file +gymnasium==0.29.1 diff --git a/tests/metagpt/tools/libs/test_terminal.py b/tests/metagpt/tools/libs/test_terminal.py new file mode 100644 index 000000000..97c33b977 --- /dev/null +++ b/tests/metagpt/tools/libs/test_terminal.py @@ -0,0 +1,15 @@ +from metagpt.const import DATA_PATH, METAGPT_ROOT +from metagpt.tools.libs.terminal import Terminal + + +def test_terminal(): + terminal = Terminal() + + terminal.run_command(f"cd {METAGPT_ROOT}") + output = terminal.run_command("pwd") + assert output.strip() == str(METAGPT_ROOT) + + # pwd now should be METAGPT_ROOT, cd data should land in DATA_PATH + terminal.run_command("cd data") + output = terminal.run_command("pwd") + assert output.strip() == str(DATA_PATH) diff --git a/tests/metagpt/tools/test_tool_convert.py b/tests/metagpt/tools/test_tool_convert.py index 061a619ce..4798d32b0 100644 --- a/tests/metagpt/tools/test_tool_convert.py +++ b/tests/metagpt/tools/test_tool_convert.py @@ -2,7 +2,10 @@ from typing import Literal, Union import pandas as pd -from metagpt.tools.tool_convert import convert_code_to_tool_schema +from metagpt.tools.tool_convert import ( + convert_code_to_tool_schema, + convert_code_to_tool_schema_ast, +) class DummyClass: @@ -128,3 +131,91 @@ def test_convert_code_to_tool_schema_function(): def test_convert_code_to_tool_schema_async_function(): schema = 
convert_code_to_tool_schema(dummy_async_fn) assert schema.get("type") == "async_function" + + +TEST_CODE_FILE_TEXT = ''' +import pandas as pd # imported obj should not be parsed +from some_module1 import some_imported_function, SomeImportedClass # imported obj should not be parsed +from ..some_module2 import some_imported_function2 # relative import should not result in an error + +class MyClass: + """This is a MyClass docstring.""" + def __init__(self, arg1): + """This is the constructor docstring.""" + self.arg1 = arg1 + + def my_method(self, arg2: Union[list[str], str], arg3: pd.DataFrame, arg4: int = 1, arg5: Literal["a","b","c"] = "a") -> Tuple[int, str]: + """ + This is a method docstring. + + Args: + arg2 (Union[list[str], str]): A union of a list of strings and a string. + ... + + Returns: + Tuple[int, str]: A tuple of an integer and a string. + """ + return self.arg4 + arg5 + + async def my_async_method(self, some_arg) -> str: + return "hi" + + def _private_method(self): # private should not be parsed + return "private" + +def my_function(arg1, arg2) -> dict: + """This is a function docstring.""" + return arg1 + arg2 + +def my_async_function(arg1, arg2) -> dict: + return arg1 + arg2 + +def _private_function(): # private should not be parsed + return "private" +''' + + +def test_convert_code_to_tool_schema_ast(): + expected = { + "MyClass": { + "type": "class", + "description": "This is a MyClass docstring.", + "methods": { + "__init__": { + "type": "function", + "description": "This is the constructor docstring.", + "signature": "(self, arg1)", + "parameters": "", + }, + "my_method": { + "type": "function", + "description": "This is a method docstring. ", + "signature": "(self, arg2: Union[list[str], str], arg3: pd.DataFrame, arg4: int = 1, arg5: Literal['a', 'b', 'c'] = 'a') -> Tuple[int, str]", + "parameters": "Args: arg2 (Union[list[str], str]): A union of a list of strings and a string. ... 
Returns: Tuple[int, str]: A tuple of an integer and a string.", + }, + "my_async_method": { + "type": "async_function", + "description": "", + "signature": "(self, some_arg) -> str", + "parameters": "", + }, + }, + "code": 'class MyClass:\n """This is a MyClass docstring."""\n def __init__(self, arg1):\n """This is the constructor docstring."""\n self.arg1 = arg1\n\n def my_method(self, arg2: Union[list[str], str], arg3: pd.DataFrame, arg4: int = 1, arg5: Literal["a","b","c"] = "a") -> Tuple[int, str]:\n """\n This is a method docstring.\n \n Args:\n arg2 (Union[list[str], str]): A union of a list of strings and a string.\n ...\n \n Returns:\n Tuple[int, str]: A tuple of an integer and a string.\n """\n return self.arg4 + arg5\n \n async def my_async_method(self, some_arg) -> str:\n return "hi"\n \n def _private_method(self): # private should not be parsed\n return "private"', + }, + "my_function": { + "type": "function", + "description": "This is a function docstring.", + "signature": "(arg1, arg2) -> dict", + "parameters": "", + "code": 'def my_function(arg1, arg2) -> dict:\n """This is a function docstring."""\n return arg1 + arg2', + }, + "my_async_function": { + "type": "function", + "description": "", + "signature": "(arg1, arg2) -> dict", + "parameters": "", + "code": "def my_async_function(arg1, arg2) -> dict:\n return arg1 + arg2", + }, + } + schemas = convert_code_to_tool_schema_ast(TEST_CODE_FILE_TEXT) + assert schemas == expected