add docstring parser

This commit is contained in:
yzlin 2024-01-22 14:58:06 +08:00
parent 0f245f530e
commit 9b3987ff29
2 changed files with 101 additions and 27 deletions

View file

@ -1,9 +1,6 @@
import inspect
import re
def remove_spaces(text: str) -> str:
    """Collapse every run of whitespace in ``text`` into a single space.

    Newlines, tabs and repeated spaces are all matched by ``\\s+`` and each
    run is replaced by one ``" "``. Leading and trailing whitespace is
    collapsed too, but not stripped.

    Args:
        text: The string to normalize.

    Returns:
        ``text`` with each whitespace run replaced by a single space.
    """
    return re.sub(r"\s+", " ", text)
from metagpt.utils.parse_docstring import GoogleDocstringParser, remove_spaces
def convert_code_to_tool_schema(obj, include: list[str] = []):
@ -34,45 +31,35 @@ def docstring_to_schema(docstring: str):
if docstring is None:
return {}
parser = GoogleDocstringParser(docstring=docstring)
# Match the description (summary) section
description_match = re.search(r"^(.*?)(?:Args:|Returns:|Raises:|$)", docstring, re.DOTALL)
description = remove_spaces(description_match.group(1)) if description_match else ""
description = parser.parse_desc()
# Match the Args section
args_match = re.search(r"Args:\s*(.*?)(?:Returns:|Raises:|$)", docstring, re.DOTALL)
_args = args_match.group(1).strip() if args_match else ""
# variable_pattern = re.compile(r"(\w+)\s*\((.*?)\):\s*(.*)")
variable_pattern = re.compile(
r"(\w+)\s*\((.*?)\):\s*(.*?)(?=\n\s*\w+\s*\(|\Z)", re.DOTALL
) # (?=\n\w+\s*\(|\Z) is to assert that what follows is either the start of the next parameter (indicated by a newline, some word characters, and an opening parenthesis) or the end of the string (\Z).
params = variable_pattern.findall(_args)
params = parser.parse_params()
parameter_schema = {"properties": {}, "required": []}
for param in params:
param_name, param_type, param_desc = param
# check required or optional
if "optional" in param_type:
param_type = param_type.replace(", optional", "")
else:
is_optional, param_type = parser.check_and_parse_optional(param_type)
if not is_optional:
parameter_schema["required"].append(param_name)
# type and desc
param_dict = {"type": param_type, "description": remove_spaces(param_desc)}
# match Default for optional args
default_val = re.search(r"Defaults to (.+?)\.", param_desc)
if default_val:
param_dict["default"] = default_val.group(1)
has_default_val, default_val = parser.check_and_parse_default_value(param_desc)
if has_default_val:
param_dict["default"] = default_val
# match Enum
enum_val = re.search(r"Enum: \[(.+?)\]", param_desc)
if enum_val:
param_dict["enum"] = [e.strip() for e in enum_val.group(1).split(",")]
has_enum, enum_vals = parser.check_and_parse_enum(param_desc)
if has_enum:
param_dict["enum"] = enum_vals
# add to parameter schema
parameter_schema["properties"].update({param_name: param_dict})
# Match the Returns section
returns_match = re.search(r"Returns:\s*(.*?)(?:Raises:|$)", docstring, re.DOTALL)
returns = returns_match.group(1).strip() if returns_match else ""
return_pattern = re.compile(r"^(.*)\s*:\s*(.*)$")
returns = return_pattern.findall(returns)
returns = parser.parse_returns()
# Build the YAML dictionary
schema = {