Merge pull request #124 from shenchucheng/main

add write docstring action
This commit is contained in:
stellaHSR 2023-08-08 11:31:12 +08:00 committed by GitHub
commit 930d18962f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 586 additions and 2 deletions

View file

@ -0,0 +1,214 @@
"""Code Docstring Generator.
This script provides a tool to automatically generate docstrings for Python code. It uses the specified style to create
docstrings for the given code and system text.
Usage:
python3 -m metagpt.actions.write_docstring <filename> [--overwrite] [--style=<docstring_style>]
Arguments:
filename The path to the Python file for which you want to generate docstrings.
Options:
--overwrite If specified, overwrite the original file with the code containing docstrings.
--style=<docstring_style> Specify the style of the generated docstrings.
Valid values: 'google', 'numpy', or 'sphinx'.
Default: 'google'
Example:
python3 -m metagpt.actions.write_docstring startup.py --overwrite False --style=numpy
This script uses the 'fire' library to create a command-line interface. It generates docstrings for the given Python code using
the specified docstring style and adds them to the code.
"""
import ast
from typing import Literal
from metagpt.actions.action import Action
from metagpt.utils.common import OutputParser
from metagpt.utils.pycst import merge_docstring
# System prompt template used by WriteDocstring.run. It is filled in with
# str.format and has two placeholders: {style} (the docstring style name) and
# {example} (a sample output written in that style).
PYTHON_DOCSTRING_SYSTEM = '''### Requirements
1. Add docstrings to the given code following the {style} style.
2. Replace the function body with an Ellipsis object(...) to reduce output.
3. If the types are already annotated, there is no need to include them in the docstring.
4. Extract only class, function or the docstrings for the module parts from the given Python code, avoiding any other text.
### Input Example
```python
def function_with_pep484_type_annotations(param1: int) -> bool:
return isinstance(param1, int)
class ExampleError(Exception):
def __init__(self, msg: str):
self.msg = msg
```
### Output Example
```python
{example}
```
'''
# https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html
# Sample output in Google docstring style. Its stripped text is registered
# under the "google" key of _python_docstring_style and substituted into the
# {example} placeholder of the system prompt.
PYTHON_DOCSTRING_EXAMPLE_GOOGLE = '''
def function_with_pep484_type_annotations(param1: int) -> bool:
"""Example function with PEP 484 type annotations.
Extended description of function.
Args:
param1: The first parameter.
Returns:
The return value. True for success, False otherwise.
"""
...
class ExampleError(Exception):
"""Exceptions are documented in the same way as classes.
The __init__ method was documented in the class level docstring.
Args:
msg: Human readable string describing the exception.
Attributes:
msg: Human readable string describing the exception.
"""
...
'''
# Sample output in NumPy docstring style. Its stripped text is registered
# under the "numpy" key of _python_docstring_style.
PYTHON_DOCSTRING_EXAMPLE_NUMPY = '''
def function_with_pep484_type_annotations(param1: int) -> bool:
"""
Example function with PEP 484 type annotations.
Extended description of function.
Parameters
----------
param1
The first parameter.
Returns
-------
bool
The return value. True for success, False otherwise.
"""
...
class ExampleError(Exception):
"""
Exceptions are documented in the same way as classes.
The __init__ method was documented in the class level docstring.
Parameters
----------
msg
Human readable string describing the exception.
Attributes
----------
msg
Human readable string describing the exception.
"""
...
'''
# Sample output in Sphinx (reST field list) docstring style. Its stripped text
# is registered under the "sphinx" key of _python_docstring_style.
PYTHON_DOCSTRING_EXAMPLE_SPHINX = '''
def function_with_pep484_type_annotations(param1: int) -> bool:
"""Example function with PEP 484 type annotations.
Extended description of function.
:param param1: The first parameter.
:type param1: int
:return: The return value. True for success, False otherwise.
:rtype: bool
"""
...
class ExampleError(Exception):
"""Exceptions are documented in the same way as classes.
The __init__ method was documented in the class level docstring.
:param msg: Human-readable string describing the exception.
:type msg: str
"""
...
'''
# Maps each supported style name to its example text, with leading and
# trailing whitespace stripped. WriteDocstring.run indexes this mapping with
# the requested style, so an unknown style name raises KeyError there.
_python_docstring_style = {
"google": PYTHON_DOCSTRING_EXAMPLE_GOOGLE.strip(),
"numpy": PYTHON_DOCSTRING_EXAMPLE_NUMPY.strip(),
"sphinx": PYTHON_DOCSTRING_EXAMPLE_SPHINX.strip(),
}
class WriteDocstring(Action):
    """Action that produces docstrings for a piece of Python source.

    Attributes:
        desc: Short human-readable description of what the action does.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.desc = "Write docstring for code."

    async def run(
        self,
        code: str,
        system_text: str = PYTHON_DOCSTRING_SYSTEM,
        style: Literal["google", "numpy", "sphinx"] = "google",
    ) -> str:
        """Generate docstrings for ``code`` and merge them back into it.

        Args:
            code: Python source to document.
            system_text: Prompt template; formatted with ``style`` and ``example`` keyword fields.
            style: Docstring style to request: 'google', 'numpy', or 'sphinx'.

        Returns:
            The original source with the generated docstrings merged in.

        Raises:
            KeyError: If ``style`` is not a key of the style-example mapping.
        """
        style_example = _python_docstring_style[style]
        prompt_rules = system_text.format(style=style, example=style_example)

        # Only the simplified source is sent out; the merge below is applied
        # to the untouched original.
        stripped_source = _simplify_python_code(code)
        user_message = f"```python\n{stripped_source}\n```"

        reply = await self._aask(user_message, [prompt_rules])
        annotated_source = OutputParser.parse_python_code(reply)
        return merge_docstring(code, annotated_source)
def _simplify_python_code(code: str) -> None:
"""Simplifies the given Python code by removing expressions and the last if statement.
Args:
code: A string of Python code.
Returns:
The simplified Python code.
"""
code_tree = ast.parse(code)
code_tree.body = [i for i in code_tree.body if not isinstance(i, ast.Expr)]
if isinstance(code_tree.body[-1], ast.If):
code_tree.body.pop()
return ast.unparse(code_tree)
if __name__ == "__main__":
    import fire

    async def run(filename: str, overwrite: bool = False, style: Literal["google", "numpy", "sphinx"] = "google"):
        """Generate docstrings for the file at ``filename``.

        Args:
            filename: Path of the Python file to document.
            overwrite: When true, write the documented code back to ``filename``.
            style: Docstring style: 'google', 'numpy', or 'sphinx'.

        Returns:
            The documented source code.
        """
        # Python source is UTF-8 by default, so do not rely on the locale
        # encoding for either the read or the write-back.
        with open(filename, encoding="utf-8") as f:
            code = f.read()
        code = await WriteDocstring().run(code, style=style)
        if overwrite:
            with open(filename, "w", encoding="utf-8") as f:
                f.write(code)
        return code

    fire.Fire(run)

View file

@ -6,6 +6,7 @@
@File : common.py
"""
import ast
import contextlib
import inspect
import os
import re
@ -78,6 +79,23 @@ class OutputParser:
else:
tasks = text.split("\n")
return tasks
@staticmethod
def parse_python_code(text: str) -> str:
    """Extract a syntactically valid Python snippet from ``text``.

    Two patterns are tried in order: the first requires a closing code fence
    after the snippet, the second does not. In both, a leading
    "```python" fence (and anything before it) is optional. The first
    non-empty candidate that ``ast.parse`` accepts is returned.

    Args:
        text: Raw text that may wrap the code in a Markdown code fence.

    Returns:
        The extracted code.

    Raises:
        ValueError: If no pattern yields a non-empty, parseable candidate.
    """
    with_closing_fence = r'(.*?```python.*?\s+)?(?P<code>.*)(```.*?)'
    without_closing_fence = r'(.*?```python.*?\s+)?(?P<code>.*)'

    for pattern in (with_closing_fence, without_closing_fence):
        found = re.search(pattern, text, re.DOTALL)
        candidate = found.group("code") if found else None
        if not candidate:
            continue
        try:
            ast.parse(candidate)
        except Exception:
            # Not valid Python under this pattern; try the next one.
            continue
        return candidate

    raise ValueError("Invalid python code")
@classmethod
def parse_data(cls, data):
@ -231,7 +249,8 @@ def print_members(module, indent=0):
elif inspect.ismethod(obj):
print(f'{prefix}Method: {name}')
def parse_recipient(text):
    """Extract the recipient name that follows a ``## Send To:`` marker.

    Args:
        text: Text to search.

    Returns:
        The first run of ASCII letters after the marker (optional whitespace
        in between), or an empty string when the marker is not found.
    """
    # Raw string so that `\s` reaches the regex engine unchanged instead of
    # being an invalid string escape.
    pattern = r"## Send To:\s*([A-Za-z]+)\s*?"  # hard code for now
    recipient = re.search(pattern, text)
    return recipient.group(1) if recipient else ""

166
metagpt/utils/pycst.py Normal file
View file

@ -0,0 +1,166 @@
from __future__ import annotations
from typing import Union
import libcst as cst
from libcst._nodes.module import Module
# The node kinds this module reads docstrings from or writes them to: a whole
# module, a class definition, or a function definition.
DocstringNode = Union[cst.Module, cst.ClassDef, cst.FunctionDef]
def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine | None:
    """Return the statement holding the docstring of a module, class or function.

    Args:
        body: The module, class definition or function definition to inspect.

    Returns:
        The first statement of the node when it is a simple statement line
        whose first element is a (possibly concatenated) string expression
        that does not evaluate to bytes. ``None`` in every other case.
    """
    # A module holds its statements directly; classes and functions hold them
    # one level deeper, inside their suite.
    if isinstance(body, cst.Module):
        body = body.body
    else:
        body = body.body.body

    if not body:
        return None

    statement = body[0]
    if not isinstance(statement, cst.SimpleStatementLine):
        return None

    # Descend to the first small statement of the line.
    expr = statement
    while isinstance(expr, (cst.BaseSuite, cst.SimpleStatementLine)):
        if len(expr.body) == 0:
            return None
        expr = expr.body[0]

    if not isinstance(expr, cst.Expr):
        return None

    val = expr.value
    if not isinstance(val, (cst.SimpleString, cst.ConcatenatedString)):
        return None

    # A bytes literal in first position is not a docstring.
    if isinstance(val.evaluated_value, bytes):
        return None

    return statement
class DocstringCollector(cst.CSTVisitor):
    """A visitor class for collecting docstrings from a CST.

    Attributes:
        stack: Names of the enclosing module/class/function nodes, outermost first.
            The module contributes an empty string.
        docstrings: Maps a path (the stack as a tuple) to the docstring statement
            found at that path.
    """

    def __init__(self):
        self.stack: list[str] = []
        self.docstrings: dict[tuple[str, ...], cst.SimpleStatementLine] = {}

    def visit_Module(self, node: cst.Module) -> bool | None:
        self.stack.append("")

    def leave_Module(self, node: cst.Module) -> None:
        return self._leave(node)

    def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_ClassDef(self, node: cst.ClassDef) -> None:
        return self._leave(node)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
        return self._leave(node)

    @staticmethod
    def _is_overload(decorator: cst.Decorator) -> bool:
        """Tell whether a decorator is ``@overload`` or ``@<something>.overload``."""
        target = decorator.decorator
        if isinstance(target, cst.Name):
            return target.value == "overload"
        if isinstance(target, cst.Attribute):
            return target.attr.value == "overload"
        # Call decorators such as ``@lru_cache(maxsize=1)`` and any other
        # expression are never overload markers. Checking the node type first
        # avoids reading ``.value`` on nodes that do not have it.
        return False

    def _leave(self, node: DocstringNode) -> None:
        """Pop the current path and record the node's docstring, if it has one.

        Overload stubs are skipped, so their docstrings are never collected.
        """
        key = tuple(self.stack)
        self.stack.pop()
        # Modules have no ``decorators`` attribute; treat that as "no decorators".
        if any(self._is_overload(dec) for dec in getattr(node, "decorators", ())):
            return

        statement = get_docstring_statement(node)
        if statement:
            self.docstrings[key] = statement
class DocstringTransformer(cst.CSTTransformer):
    """A transformer class for replacing docstrings in a CST.

    Attributes:
        stack: Names of the enclosing module/class/function nodes, outermost first.
            The module contributes an empty string.
        docstrings: Maps a path (the stack as a tuple) to the docstring statement
            that should be placed at that path.
    """

    def __init__(
        self,
        docstrings: dict[tuple[str, ...], cst.SimpleStatementLine],
    ):
        self.stack: list[str] = []
        self.docstrings = docstrings

    def visit_Module(self, node: cst.Module) -> bool | None:
        self.stack.append("")

    def leave_Module(self, original_node: Module, updated_node: Module) -> Module:
        return self._leave(original_node, updated_node)

    def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.CSTNode:
        return self._leave(original_node, updated_node)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode:
        return self._leave(original_node, updated_node)

    @staticmethod
    def _is_overload(decorator: cst.Decorator) -> bool:
        """Tell whether a decorator is ``@overload`` or ``@<something>.overload``."""
        target = decorator.decorator
        if isinstance(target, cst.Name):
            return target.value == "overload"
        if isinstance(target, cst.Attribute):
            return target.attr.value == "overload"
        # Call decorators such as ``@lru_cache(maxsize=1)`` and any other
        # expression are never overload markers. Checking the node type first
        # avoids reading ``.value`` on nodes that do not have it.
        return False

    def _leave(self, original_node: DocstringNode, updated_node: DocstringNode) -> DocstringNode:
        """Pop the current path and install the docstring registered for it.

        Args:
            original_node: The node as it was before any child was transformed.
            updated_node: The node carrying the already-transformed children.

        Returns:
            ``updated_node`` unchanged when it is an overload stub or no docstring
            is registered for its path; otherwise a copy whose first statement is
            the registered docstring (replacing an existing one, or inserted
            ahead of the body).
        """
        key = tuple(self.stack)
        self.stack.pop()

        # Modules have no ``decorators`` attribute; treat that as "no decorators".
        if any(self._is_overload(dec) for dec in getattr(updated_node, "decorators", ())):
            return updated_node

        statement = self.docstrings.get(key)
        if not statement:
            return updated_node

        original_statement = get_docstring_statement(original_node)

        if isinstance(updated_node, cst.Module):
            body = updated_node.body
            if original_statement:
                # Replace the existing module docstring.
                return updated_node.with_changes(body=(statement, *body[1:]))
            # No module docstring yet: insert one, followed by an empty line.
            return updated_node.with_changes(body=(statement, cst.EmptyLine(), *body))

        # Classes and functions keep their statements inside a suite.
        body = updated_node.body.body[1:] if original_statement else updated_node.body.body
        return updated_node.with_changes(body=updated_node.body.with_changes(body=(statement, *body)))
def merge_docstring(code: str, documented_code: str) -> str:
    """Copy the docstrings found in ``documented_code`` into ``code``.

    Both sources are parsed with libcst. Docstrings are collected from the
    documented source, keyed by their path of enclosing names, and installed at
    the same paths in the original source.

    Args:
        code: The original code.
        documented_code: The code carrying the docstrings to transfer.

    Returns:
        The source text of the original code after the docstrings were merged in.
    """
    source_tree = cst.parse_module(code)
    reference_tree = cst.parse_module(documented_code)

    collector = DocstringCollector()
    reference_tree.visit(collector)

    merged_tree = source_tree.visit(DocstringTransformer(collector.docstrings))
    return merged_tree.code

View file

@ -35,3 +35,4 @@ tqdm==4.64.0
anthropic==0.3.6
typing-inspect==0.8.0
typing_extensions==4.5.0
libcst==1.0.1

View file

@ -0,0 +1,32 @@
import pytest
from metagpt.actions.write_docstring import WriteDocstring
# Undocumented sample source (one function and one class with two methods)
# that test_write_docstring passes to WriteDocstring().run.
code = '''
def add_numbers(a: int, b: int):
return a + b
class Person:
def __init__(self, name: str, age: int):
self.name = name
self.age = age
def greet(self):
return f"Hello, my name is {self.name} and I am {self.age} years old."
'''
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("style", "part"),
    [
        pytest.param("google", "Args:", id="google"),
        pytest.param("numpy", "Parameters", id="numpy"),
        pytest.param("sphinx", ":param name:", id="sphinx"),
    ],
)
async def test_write_docstring(style: str, part: str):
    """The documented output must contain the marker typical of the requested style."""
    documented = await WriteDocstring().run(code, style=style)
    assert part in documented

View file

@ -19,7 +19,7 @@ def test_parse_blocks():
def test_parse_code():
test_text = "```python\nprint('Hello, world!')\n```"
test_text = "```python\nprint('Hello, world!')```"
expected_result = "print('Hello, world!')"
assert OutputParser.parse_code(test_text, 'python') == expected_result
@ -27,6 +27,22 @@ def test_parse_code():
OutputParser.parse_code(test_text, 'java')
def test_parse_python_code():
    """parse_python_code extracts the snippet with or without fences and rejects invalid code."""
    plain = "print('Hello, world!')"
    # The snippet itself contains fence markers inside a string literal.
    nested = "print('```Hello, world!```')"
    cases = [
        ("```python\nprint('Hello, world!')```", plain),
        ("```python\nprint('Hello, world!')", plain),
        ("print('Hello, world!')", plain),
        ("print('Hello, world!')```", plain),
        ("```python\nprint('```Hello, world!```')```", nested),
        ("The code is: ```python\nprint('```Hello, world!```')```", nested),
        ("xxx.\n```python\nprint('```Hello, world!```')```\nxxx", nested),
    ]
    for raw_text, expected_result in cases:
        assert OutputParser.parse_python_code(raw_text) == expected_result

    with pytest.raises(ValueError):
        OutputParser.parse_python_code("xxx =")
def test_parse_str():
test_text = "name = 'Alice'"
expected_result = 'Alice'

View file

@ -0,0 +1,136 @@
from metagpt.utils import pycst
# Undocumented input source that test_merge_docstring passes as the first
# argument of pycst.merge_docstring. It contains two @overload stubs, the real
# add_numbers function, and a Person class.
code = '''
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import overload
@overload
def add_numbers(a: int, b: int):
...
@overload
def add_numbers(a: float, b: float):
...
def add_numbers(a: int, b: int):
return a + b
class Person:
def __init__(self, name: str, age: int):
self.name = name
self.age = age
def greet(self):
return f"Hello, my name is {self.name} and I am {self.age} years old."
'''
# Documented counterpart of `code`, passed as the second argument of
# pycst.merge_docstring. It carries module, function, class and method
# docstrings; the method bodies are replaced with `...`.
documented_code = '''
"""
This is an example module containing a function and a class definition.
"""
def add_numbers(a: int, b: int):
"""This function is used to add two numbers and return the result.
Parameters:
a: The first integer.
b: The second integer.
Returns:
int: The sum of the two numbers.
"""
return a + b
class Person:
"""This class represents a person's information, including name and age.
Attributes:
name: The person's name.
age: The person's age.
"""
def __init__(self, name: str, age: int):
"""Creates a new instance of the Person class.
Parameters:
name: The person's name.
age: The person's age.
"""
...
def greet(self):
"""
Returns a greeting message including the name and age.
Returns:
str: The greeting message.
"""
...
'''
# Expected result of merging `documented_code` into `code`: the original
# statements and bodies are kept, the docstrings are added, and the two
# @overload stubs stay without docstrings.
merged_code = '''
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This is an example module containing a function and a class definition.
"""
from typing import overload
@overload
def add_numbers(a: int, b: int):
...
@overload
def add_numbers(a: float, b: float):
...
def add_numbers(a: int, b: int):
"""This function is used to add two numbers and return the result.
Parameters:
a: The first integer.
b: The second integer.
Returns:
int: The sum of the two numbers.
"""
return a + b
class Person:
"""This class represents a person's information, including name and age.
Attributes:
name: The person's name.
age: The person's age.
"""
def __init__(self, name: str, age: int):
"""Creates a new instance of the Person class.
Parameters:
name: The person's name.
age: The person's age.
"""
self.name = name
self.age = age
def greet(self):
"""
Returns a greeting message including the name and age.
Returns:
str: The greeting message.
"""
return f"Hello, my name is {self.name} and I am {self.age} years old."
'''
def test_merge_docstring():
    """Merging `documented_code` into `code` must reproduce `merged_code` exactly."""
    result = pycst.merge_docstring(code, documented_code)
    # Echo the merged source so a mismatch can be inspected in the captured output.
    print(result)
    assert result == merged_code