diff --git a/metagpt/actions/write_docstring.py b/metagpt/actions/write_docstring.py index 091bd0a82..370d020de 100644 --- a/metagpt/actions/write_docstring.py +++ b/metagpt/actions/write_docstring.py @@ -8,10 +8,9 @@ from metagpt.utils.pycst import merge_docstring PYTHON_DOCSTRING_SYSTEM = '''### Requirements 1. Add docstrings to the given code following the {style} style. -2. Remove all private members whose names start with an underscore, such as `_test` and `__init__`. -3. Replace the function body with an Ellipsis object(...) to reduce output. -4. If the types are already annotated, there is no need to include them in the docstring. -5. Only output Python code and avoid including any other text. +2. Replace the function body with an Ellipsis object(...) to reduce output. +3. If the types are already annotated, there is no need to include them in the docstring. +4. Extract only class, function or the docstrings for the module parts from the given Python code, avoiding any other text. ### Input Example ```python @@ -128,6 +127,11 @@ _python_docstring_style = { class WriteDocstring(Action): + """This class is used to write docstrings for code. + + Attributes: + desc: A string describing the action. + """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -138,15 +142,33 @@ class WriteDocstring(Action): system_text: str = PYTHON_DOCSTRING_SYSTEM, style: Literal["google", "numpy", "sphinx"] = "google", ) -> str: + """Writes docstrings for the given code and system text in the specified style. + + Args: + code: A string of Python code. + system_text: A string of system text. + style: A string specifying the style of the docstring. Can be 'google', 'numpy', or 'sphinx'. + + Returns: + The Python code with docstrings added. + """ system_text = system_text.format(style=style, example=_python_docstring_style[style]) simplified_code = _simplify_python_code(code) - documented_code = await self._aask(simplified_code, [system_text]) + documented_code = await self._aask(f"```python\n{simplified_code}\n```", [system_text]) with contextlib.suppress(Exception): documented_code = OutputParser.parse_code(documented_code) return merge_docstring(code, documented_code) def _simplify_python_code(code: str) -> None: + """Simplifies the given Python code by removing expressions and the last if statement. + + Args: + code: A string of Python code. + + Returns: + The simplified Python code. + """ code_tree = ast.parse(code) code_tree.body = [i for i in code_tree.body if not isinstance(i, ast.Expr)] if isinstance(code_tree.body[-1], ast.If): diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 0cd73ec0b..d695fd4c4 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -49,7 +49,7 @@ class OutputParser: @classmethod def parse_code(cls, text: str, lang: str = "") -> str: - pattern = rf'```{lang}.*?\s+(.*?)```' + pattern = rf'```{lang}.*?\s+(.*)```' match = re.search(pattern, text, re.DOTALL) if match: code = match.group(1) @@ -231,7 +231,8 @@ def print_members(module, indent=0): elif inspect.ismethod(obj): print(f'{prefix}Method: {name}') + def parse_recipient(text): - pattern = "## Send To:\s*([A-Za-z]+)\s*?" # hard code for now + pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now recipient = re.search(pattern, text) return recipient.group(1) if recipient else "" diff --git a/metagpt/utils/pycst.py b/metagpt/utils/pycst.py index c2eb532ab..afd85a547 100644 --- a/metagpt/utils/pycst.py +++ b/metagpt/utils/pycst.py @@ -137,16 +137,13 @@ class DocstringTransformer(cst.CSTTransformer): if isinstance(updated_node, cst.Module): body = updated_node.body if original_statement: - return updated_node.with_changes(body=(body[0], statement, *body[1:])) + return updated_node.with_changes(body=(statement, *body[1:])) else: updated_node = updated_node.with_changes(body=(statement, cst.EmptyLine(), *body)) return updated_node - body = updated_node.body.body - if original_statement: - return updated_node.with_changes(body=updated_node.body.with_changes(body=(body[0], statement, *body[1:]))) - else: - return updated_node.with_changes(body=updated_node.body.with_changes(body=(statement, *body))) + body = updated_node.body.body[1:] if original_statement else updated_node.body.body + return updated_node.with_changes(body=updated_node.body.with_changes(body=(statement, *body))) def merge_docstring(code: str, documented_code: str) -> str: