From 8b1fff22155fb0844376e76f1eaa55d175ae015e Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Mon, 7 Aug 2023 20:58:01 +0800 Subject: [PATCH] add docs for write_docstring.py and parse python code with retry --- metagpt/actions/write_docstring.py | 27 ++++++++++++++++++++++++--- metagpt/utils/common.py | 20 +++++++++++++++++++- tests/metagpt/utils/test_output_parser.py | 18 +++++++++++++++++- 3 files changed, 60 insertions(+), 5 deletions(-) diff --git a/metagpt/actions/write_docstring.py b/metagpt/actions/write_docstring.py index db1928872..5c7815793 100644 --- a/metagpt/actions/write_docstring.py +++ b/metagpt/actions/write_docstring.py @@ -1,5 +1,27 @@ +"""Code Docstring Generator. + +This script provides a tool to automatically generate docstrings for Python code. It uses the specified style to create +docstrings for the given code and system text. + +Usage: + python3 -m metagpt.actions.write_docstring <filename> [--overwrite] [--style=<docstring_style>] + +Arguments: + filename The path to the Python file for which you want to generate docstrings. + +Options: + --overwrite If specified, overwrite the original file with the code containing docstrings. + --style=<docstring_style> Specify the style of the generated docstrings. + Valid values: 'google', 'numpy', or 'sphinx'. + Default: 'google' + +Example: + python3 -m metagpt.actions.write_docstring startup.py --overwrite False --style=numpy + +This script uses the 'fire' library to create a command-line interface. It generates docstrings for the given Python code using +the specified docstring style and adds them to the code. 
+""" import ast -import contextlib from typing import Literal from metagpt.actions.action import Action @@ -157,8 +179,7 @@ class WriteDocstring(Action): system_text = system_text.format(style=style, example=_python_docstring_style[style]) simplified_code = _simplify_python_code(code) documented_code = await self._aask(f"```python\n{simplified_code}\n```", [system_text]) - with contextlib.suppress(Exception): - documented_code = OutputParser.parse_code(documented_code) + documented_code = OutputParser.parse_python_code(documented_code) return merge_docstring(code, documented_code) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index d695fd4c4..7f090cf63 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -6,6 +6,7 @@ @File : common.py """ import ast +import contextlib import inspect import os import re @@ -49,7 +50,7 @@ class OutputParser: @classmethod def parse_code(cls, text: str, lang: str = "") -> str: - pattern = rf'```{lang}.*?\s+(.*)```' + pattern = rf'```{lang}.*?\s+(.*?)```' match = re.search(pattern, text, re.DOTALL) if match: code = match.group(1) @@ -78,6 +79,23 @@ class OutputParser: else: tasks = text.split("\n") return tasks + + @staticmethod + def parse_python_code(text: str) -> str: + for pattern in ( + r'(.*?```python.*?\s+)?(?P<code>.*)(```.*?)', + r'(.*?```python.*?\s+)?(?P<code>.*)', + ): + match = re.search(pattern, text, re.DOTALL) + if not match: + continue + code = match.group("code") + if not code: + continue + with contextlib.suppress(Exception): + ast.parse(code) + return code + raise ValueError("Invalid python code") @classmethod def parse_data(cls, data): diff --git a/tests/metagpt/utils/test_output_parser.py b/tests/metagpt/utils/test_output_parser.py index 155297860..c56cff6fa 100644 --- a/tests/metagpt/utils/test_output_parser.py +++ b/tests/metagpt/utils/test_output_parser.py @@ -19,7 +19,7 @@ def test_parse_blocks(): def test_parse_code(): - test_text = "```python\nprint('Hello, world!')\n```" + test_text = 
"```python\nprint('Hello, world!')```" expected_result = "print('Hello, world!')" assert OutputParser.parse_code(test_text, 'python') == expected_result @@ -27,6 +27,22 @@ def test_parse_code(): OutputParser.parse_code(test_text, 'java') +def test_parse_python_code(): + expected_result = "print('Hello, world!')" + assert OutputParser.parse_python_code("```python\nprint('Hello, world!')```") == expected_result + assert OutputParser.parse_python_code("```python\nprint('Hello, world!')") == expected_result + assert OutputParser.parse_python_code("print('Hello, world!')") == expected_result + assert OutputParser.parse_python_code("print('Hello, world!')```") == expected_result + assert OutputParser.parse_python_code("print('Hello, world!')```") == expected_result + expected_result = "print('```Hello, world!```')" + assert OutputParser.parse_python_code("```python\nprint('```Hello, world!```')```") == expected_result + assert OutputParser.parse_python_code("The code is: ```python\nprint('```Hello, world!```')```") == expected_result + assert OutputParser.parse_python_code("xxx.\n```python\nprint('```Hello, world!```')```\nxxx") == expected_result + + with pytest.raises(ValueError): + OutputParser.parse_python_code("xxx =") + + def test_parse_str(): test_text = "name = 'Alice'" expected_result = 'Alice'