diff --git a/metagpt/tools/tool_convert.py b/metagpt/tools/tool_convert.py index fdb69bfb3..b8377e67a 100644 --- a/metagpt/tools/tool_convert.py +++ b/metagpt/tools/tool_convert.py @@ -1,9 +1,6 @@ import inspect -import re - -def remove_spaces(text): - return re.sub(r"\s+", " ", text) +from metagpt.utils.parse_docstring import GoogleDocstringParser, remove_spaces def convert_code_to_tool_schema(obj, include: list[str] = []): @@ -34,45 +31,35 @@ def docstring_to_schema(docstring: str): if docstring is None: return {} + parser = GoogleDocstringParser(docstring=docstring) + # 匹配简介部分 - description_match = re.search(r"^(.*?)(?:Args:|Returns:|Raises:|$)", docstring, re.DOTALL) - description = remove_spaces(description_match.group(1)) if description_match else "" + description = parser.parse_desc() # 匹配Args部分 - args_match = re.search(r"Args:\s*(.*?)(?:Returns:|Raises:|$)", docstring, re.DOTALL) - _args = args_match.group(1).strip() if args_match else "" - # variable_pattern = re.compile(r"(\w+)\s*\((.*?)\):\s*(.*)") - variable_pattern = re.compile( - r"(\w+)\s*\((.*?)\):\s*(.*?)(?=\n\s*\w+\s*\(|\Z)", re.DOTALL - ) # (?=\n\w+\s*\(|\Z) is to assert that what follows is either the start of the next parameter (indicated by a newline, some word characters, and an opening parenthesis) or the end of the string (\Z). 
- - params = variable_pattern.findall(_args) + params = parser.parse_params() parameter_schema = {"properties": {}, "required": []} for param in params: param_name, param_type, param_desc = param # check required or optional - if "optional" in param_type: - param_type = param_type.replace(", optional", "") - else: + is_optional, param_type = parser.check_and_parse_optional(param_type) + if not is_optional: parameter_schema["required"].append(param_name) # type and desc param_dict = {"type": param_type, "description": remove_spaces(param_desc)} # match Default for optional args - default_val = re.search(r"Defaults to (.+?)\.", param_desc) - if default_val: - param_dict["default"] = default_val.group(1) + has_default_val, default_val = parser.check_and_parse_default_value(param_desc) + if has_default_val: + param_dict["default"] = default_val # match Enum - enum_val = re.search(r"Enum: \[(.+?)\]", param_desc) - if enum_val: - param_dict["enum"] = [e.strip() for e in enum_val.group(1).split(",")] + has_enum, enum_vals = parser.check_and_parse_enum(param_desc) + if has_enum: + param_dict["enum"] = enum_vals # add to parameter schema parameter_schema["properties"].update({param_name: param_dict}) # 匹配Returns部分 - returns_match = re.search(r"Returns:\s*(.*?)(?:Raises:|$)", docstring, re.DOTALL) - returns = returns_match.group(1).strip() if returns_match else "" - return_pattern = re.compile(r"^(.*)\s*:\s*(.*)$") - returns = return_pattern.findall(returns) + returns = parser.parse_returns() # 构建YAML字典 schema = { diff --git a/metagpt/utils/parse_docstring.py b/metagpt/utils/parse_docstring.py new file mode 100644 index 000000000..970257676 --- /dev/null +++ b/metagpt/utils/parse_docstring.py @@ -0,0 +1,87 @@ +import re +from typing import Tuple + +from pydantic import BaseModel + + +def remove_spaces(text): + return re.sub(r"\s+", " ", text) + + +class DocstringParser(BaseModel): + docstring: str + + def parse_desc(self) -> str: + """Parse and return the description from 
the docstring.""" + + def parse_params(self) -> list[Tuple[str, str, str]]: + """Parse and return the parameters from the docstring. + + Returns: + list[Tuple[str, str, str]]: A list of input parameter info. Each info is a triple of (param name, param type, param description) + """ + + def parse_returns(self) -> list[Tuple[str, str]]: + """Parse and return the return information from the docstring. + + Returns: + list[Tuple[str, str]]: A list of output info. Each info is a tuple of (return type, return description) + """ + + @staticmethod + def check_and_parse_optional(param_type: str) -> Tuple[bool, str]: + """Check if a parameter is optional and return the param_type stripped of the optionality info if so""" + + @staticmethod + def check_and_parse_default_value(param_desc: str) -> Tuple[bool, str]: + """Check if a parameter has a default value and return the default value if so""" + + @staticmethod + def check_and_parse_enum(param_desc: str) -> Tuple[bool, str]: + """Check if a parameter description includes an enum and return enum values if so""" + + +class reSTDocstringParser(DocstringParser): + """A parser for reStructuredText (reST) docstring""" + + +class GoogleDocstringParser(DocstringParser): + """A parser for Google-style docstring""" + + docstring: str + + def parse_desc(self) -> str: + description_match = re.search(r"^(.*?)(?:Args:|Returns:|Raises:|$)", self.docstring, re.DOTALL) + description = remove_spaces(description_match.group(1)) if description_match else "" + return description + + def parse_params(self) -> list[Tuple[str, str, str]]: + args_match = re.search(r"Args:\s*(.*?)(?:Returns:|Raises:|$)", self.docstring, re.DOTALL) + _args = args_match.group(1).strip() if args_match else "" + # variable_pattern = re.compile(r"(\w+)\s*\((.*?)\):\s*(.*)") + variable_pattern = re.compile( + r"(\w+)\s*\((.*?)\):\s*(.*?)(?=\n\s*\w+\s*\(|\Z)", re.DOTALL + ) # (?=\n\w+\s*\(|\Z) is to assert that what follows is either the start of the next parameter 
(indicated by a newline, some word characters, and an opening parenthesis) or the end of the string (\Z). + params = variable_pattern.findall(_args) + return params + + def parse_returns(self) -> list[Tuple[str, str]]: + returns_match = re.search(r"Returns:\s*(.*?)(?:Raises:|$)", self.docstring, re.DOTALL) + returns = returns_match.group(1).strip() if returns_match else "" + return_pattern = re.compile(r"^(.*)\s*:\s*(.*)$") + returns = return_pattern.findall(returns) + return returns + + @staticmethod + def check_and_parse_optional(param_type: str) -> Tuple[bool, str]: + return "optional" in param_type, param_type.replace(", optional", "") + + @staticmethod + def check_and_parse_default_value(param_desc: str) -> Tuple[bool, str]: + default_val = re.search(r"Defaults to (.+?)\.", param_desc) + return (True, default_val.group(1)) if default_val else (False, "") + + @staticmethod + def check_and_parse_enum(param_desc: str) -> Tuple[bool, str]: + enum_val = re.search(r"Enum: \[(.+?)\]", param_desc) + return (True, [e.strip() for e in enum_val.group(1).split(",")]) if enum_val else (False, [])