mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
add write docstring action
This commit is contained in:
parent
a5cb2fdd48
commit
bec5778dd0
5 changed files with 507 additions and 0 deletions
169
metagpt/actions/write_docstring.py
Normal file
169
metagpt/actions/write_docstring.py
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
import ast
|
||||
import contextlib
|
||||
from typing import Literal
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.utils.common import OutputParser
|
||||
from metagpt.utils.pycst import merge_docstring
|
||||
|
||||
PYTHON_DOCSTRING_SYSTEM = '''### Requirements
|
||||
1. Add docstrings to the given code following the {style} style.
|
||||
2. Remove all private members whose names start with an underscore, such as `_test` and `__init__`.
|
||||
3. Replace the function body with an Ellipsis object(...) to reduce output.
|
||||
4. If the types are already annotated, there is no need to include them in the docstring.
|
||||
5. Only output Python code and avoid including any other text.
|
||||
|
||||
### Input Example
|
||||
```python
|
||||
def function_with_pep484_type_annotations(param1: int) -> bool:
|
||||
return isinstanc(param1, int)
|
||||
|
||||
class ExampleError(Exception):
|
||||
def __init__(self, msg: str):
|
||||
self.msg = msg
|
||||
```
|
||||
|
||||
### Output Example
|
||||
```python{example}```
|
||||
'''
|
||||
|
||||
# https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html
|
||||
|
||||
PYTHON_DOCSTRING_EXAMPLE_GOOGLE = '''
|
||||
def function_with_pep484_type_annotations(param1: int) -> bool:
|
||||
"""Example function with PEP 484 type annotations.
|
||||
|
||||
Extended description of function.
|
||||
|
||||
Args:
|
||||
param1: The first parameter.
|
||||
|
||||
Returns:
|
||||
The return value. True for success, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
class ExampleError(Exception):
|
||||
"""Exceptions are documented in the same way as classes.
|
||||
|
||||
The __init__ method was documented in the class level docstring.
|
||||
|
||||
Args:
|
||||
msg: Human readable string describing the exception.
|
||||
|
||||
Attributes:
|
||||
msg: Human readable string describing the exception.
|
||||
"""
|
||||
...
|
||||
'''
|
||||
|
||||
PYTHON_DOCSTRING_EXAMPLE_NUMPY = '''
|
||||
def function_with_pep484_type_annotations(param1: int) -> bool:
|
||||
"""
|
||||
Example function with PEP 484 type annotations.
|
||||
|
||||
Extended description of function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
param1
|
||||
The first parameter.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
The return value. True for success, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
class ExampleError(Exception):
|
||||
"""
|
||||
Exceptions are documented in the same way as classes.
|
||||
|
||||
The __init__ method was documented in the class level docstring.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
msg
|
||||
Human readable string describing the exception.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
msg
|
||||
Human readable string describing the exception.
|
||||
"""
|
||||
...
|
||||
'''
|
||||
|
||||
PYTHON_DOCSTRING_EXAMPLE_SPHINX = '''
|
||||
def function_with_pep484_type_annotations(param1: int) -> bool:
|
||||
"""Example function with PEP 484 type annotations.
|
||||
|
||||
Extended description of function.
|
||||
|
||||
:param param1: The first parameter.
|
||||
:type param1: int
|
||||
|
||||
:return: The return value. True for success, False otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
...
|
||||
|
||||
class ExampleError(Exception):
|
||||
"""Exceptions are documented in the same way as classes.
|
||||
|
||||
The __init__ method was documented in the class level docstring.
|
||||
|
||||
:param msg: Human-readable string describing the exception.
|
||||
:type msg: str
|
||||
"""
|
||||
...
|
||||
'''
|
||||
|
||||
_python_docstring_style = {
|
||||
"google": PYTHON_DOCSTRING_EXAMPLE_GOOGLE,
|
||||
"numpy": PYTHON_DOCSTRING_EXAMPLE_NUMPY,
|
||||
"sphinx": PYTHON_DOCSTRING_EXAMPLE_SPHINX,
|
||||
}
|
||||
|
||||
|
||||
class WriteDocstring(Action):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.desc = "Write docstring for code."
|
||||
|
||||
async def run(
|
||||
self, code: str,
|
||||
system_text: str = PYTHON_DOCSTRING_SYSTEM,
|
||||
style: Literal["google", "numpy", "sphinx"] = "google",
|
||||
) -> str:
|
||||
system_text = system_text.format(style=style, example=_python_docstring_style[style])
|
||||
simplified_code = _simplify_python_code(code)
|
||||
documented_code = await self._aask(simplified_code, [system_text])
|
||||
with contextlib.suppress(Exception):
|
||||
documented_code = OutputParser.parse_code(documented_code)
|
||||
return merge_docstring(code, documented_code)
|
||||
|
||||
|
||||
def _simplify_python_code(code: str) -> None:
|
||||
code_tree = ast.parse(code)
|
||||
code_tree.body = [i for i in code_tree.body if not isinstance(i, ast.Expr)]
|
||||
if isinstance(code_tree.body[-1], ast.If):
|
||||
code_tree.body.pop()
|
||||
return ast.unparse(code_tree)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
async def run(filename: str, overwrite: bool = False, style: Literal["google", "numpy", "sphinx"] = "google"):
|
||||
with open(filename) as f:
|
||||
code = f.read()
|
||||
code = await WriteDocstring().run(code, style=style)
|
||||
if overwrite:
|
||||
with open(filename, "w") as f:
|
||||
f.write(code)
|
||||
return code
|
||||
|
||||
fire.Fire(run)
|
||||
169
metagpt/utils/pycst.py
Normal file
169
metagpt/utils/pycst.py
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Union
|
||||
|
||||
import libcst as cst
|
||||
from libcst._nodes.module import Module
|
||||
|
||||
DocstringNode = Union[cst.Module, cst.ClassDef, cst.FunctionDef]
|
||||
|
||||
|
||||
def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine:
|
||||
"""Extracts the docstring from the body of a node.
|
||||
|
||||
Args:
|
||||
body: The body of a node.
|
||||
|
||||
Returns:
|
||||
The docstring statement if it exists, None otherwise.
|
||||
"""
|
||||
if isinstance(body, cst.Module):
|
||||
body = body.body
|
||||
else:
|
||||
body = body.body.body
|
||||
|
||||
if not body:
|
||||
return
|
||||
|
||||
statement = body[0]
|
||||
if not isinstance(statement, cst.SimpleStatementLine):
|
||||
return
|
||||
|
||||
expr = statement
|
||||
while isinstance(expr, (cst.BaseSuite, cst.SimpleStatementLine)):
|
||||
if len(expr.body) == 0:
|
||||
return None
|
||||
expr = expr.body[0]
|
||||
|
||||
if not isinstance(expr, cst.Expr):
|
||||
return None
|
||||
|
||||
val = expr.value
|
||||
if not isinstance(val, (cst.SimpleString, cst.ConcatenatedString)):
|
||||
return None
|
||||
|
||||
evaluated_value = val.evaluated_value
|
||||
if isinstance(evaluated_value, bytes):
|
||||
return None
|
||||
|
||||
return statement
|
||||
|
||||
|
||||
class DocstringCollector(cst.CSTVisitor):
|
||||
"""A visitor class for collecting docstrings from a CST.
|
||||
|
||||
Attributes:
|
||||
stack: A list to keep track of the current path in the CST.
|
||||
docstrings: A dictionary mapping paths in the CST to their corresponding docstrings.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.stack: list[str] = []
|
||||
self.docstrings: dict[tuple[str, ...], cst.SimpleStatementLine] = {}
|
||||
|
||||
def visit_Module(self, node: cst.Module) -> bool | None:
|
||||
self.stack.append("")
|
||||
|
||||
def leave_Module(self, node: cst.Module) -> None:
|
||||
return self._leave(node)
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
|
||||
self.stack.append(node.name.value)
|
||||
|
||||
def leave_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
return self._leave(node)
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
|
||||
self.stack.append(node.name.value)
|
||||
|
||||
def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
return self._leave(node)
|
||||
|
||||
def _leave(self, node: DocstringNode) -> None:
|
||||
key = tuple(self.stack)
|
||||
self.stack.pop()
|
||||
if hasattr(node, "decorators") and any(i.decorator.value == "overload" for i in node.decorators):
|
||||
return
|
||||
|
||||
statement = get_docstring_statement(node)
|
||||
if statement:
|
||||
self.docstrings[key] = statement
|
||||
|
||||
|
||||
class DocstringTransformer(cst.CSTTransformer):
|
||||
"""A transformer class for replacing docstrings in a CST.
|
||||
|
||||
Attributes:
|
||||
stack: A list to keep track of the current path in the CST.
|
||||
docstrings: A dictionary mapping paths in the CST to their corresponding docstrings.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
docstrings: dict[tuple[str, ...], cst.SimpleStatementLine],
|
||||
):
|
||||
self.stack: list[str] = []
|
||||
self.docstrings = docstrings
|
||||
|
||||
def visit_Module(self, node: cst.Module) -> bool | None:
|
||||
self.stack.append("")
|
||||
|
||||
def leave_Module(self, original_node: Module, updated_node: Module) -> Module:
|
||||
return self._leave(original_node, updated_node)
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
|
||||
self.stack.append(node.name.value)
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.CSTNode:
|
||||
return self._leave(original_node, updated_node)
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
|
||||
self.stack.append(node.name.value)
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode:
|
||||
return self._leave(original_node, updated_node)
|
||||
|
||||
def _leave(self, original_node: DocstringNode, updated_node: DocstringNode) -> DocstringNode:
|
||||
key = tuple(self.stack)
|
||||
self.stack.pop()
|
||||
|
||||
if hasattr(updated_node, "decorators") and any((i.decorator.value == "overload") for i in updated_node.decorators):
|
||||
return updated_node
|
||||
|
||||
statement = self.docstrings.get(key)
|
||||
if not statement:
|
||||
return updated_node
|
||||
|
||||
original_statement = get_docstring_statement(original_node)
|
||||
|
||||
if isinstance(updated_node, cst.Module):
|
||||
body = updated_node.body
|
||||
if original_statement:
|
||||
return updated_node.with_changes(body=(body[0], statement, *body[1:]))
|
||||
else:
|
||||
updated_node = updated_node.with_changes(body=(statement, cst.EmptyLine(), *body))
|
||||
return updated_node
|
||||
|
||||
body = updated_node.body.body
|
||||
if original_statement:
|
||||
return updated_node.with_changes(body=updated_node.body.with_changes(body=(body[0], statement, *body[1:])))
|
||||
else:
|
||||
return updated_node.with_changes(body=updated_node.body.with_changes(body=(statement, *body)))
|
||||
|
||||
|
||||
def merge_docstring(code: str, documented_code: str) -> str:
|
||||
"""Merges the docstrings from the documented code into the original code.
|
||||
|
||||
Args:
|
||||
code: The original code.
|
||||
documented_code: The documented code.
|
||||
|
||||
Returns:
|
||||
The original code with the docstrings from the documented code.
|
||||
"""
|
||||
code_tree = cst.parse_module(code)
|
||||
documented_code_tree = cst.parse_module(documented_code)
|
||||
|
||||
visitor = DocstringCollector()
|
||||
documented_code_tree.visit(visitor)
|
||||
transformer = DocstringTransformer(visitor.docstrings)
|
||||
modified_tree = code_tree.visit(transformer)
|
||||
return modified_tree.code
|
||||
|
|
@ -35,3 +35,4 @@ tqdm==4.64.0
|
|||
anthropic==0.3.6
|
||||
typing-inspect==0.8.0
|
||||
typing_extensions==4.5.0
|
||||
libcst==1.0.1
|
||||
|
|
|
|||
32
tests/metagpt/actions/test_write_docstring.py
Normal file
32
tests/metagpt/actions/test_write_docstring.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.write_docstring import WriteDocstring
|
||||
|
||||
code = '''
|
||||
def add_numbers(a: int, b: int):
|
||||
return a + b
|
||||
|
||||
|
||||
class Person:
|
||||
def __init__(self, name: str, age: int):
|
||||
self.name = name
|
||||
self.age = age
|
||||
|
||||
def greet(self):
|
||||
return f"Hello, my name is {self.name} and I am {self.age} years old."
|
||||
'''
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("style", "part"),
|
||||
[
|
||||
("google", "Args:"),
|
||||
("numpy", "Parameters"),
|
||||
("sphinx", ":param name:"),
|
||||
],
|
||||
ids=["google", "numpy", "sphinx"]
|
||||
)
|
||||
async def test_write_docstring(style: str, part: str):
|
||||
ret = await WriteDocstring().run(code, style=style)
|
||||
assert part in ret
|
||||
136
tests/metagpt/utils/test_pycst.py
Normal file
136
tests/metagpt/utils/test_pycst.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
from metagpt.utils import pycst
|
||||
|
||||
code = '''
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import overload
|
||||
|
||||
@overload
|
||||
def add_numbers(a: int, b: int):
|
||||
...
|
||||
|
||||
@overload
|
||||
def add_numbers(a: float, b: float):
|
||||
...
|
||||
|
||||
def add_numbers(a: int, b: int):
|
||||
return a + b
|
||||
|
||||
|
||||
class Person:
|
||||
def __init__(self, name: str, age: int):
|
||||
self.name = name
|
||||
self.age = age
|
||||
|
||||
def greet(self):
|
||||
return f"Hello, my name is {self.name} and I am {self.age} years old."
|
||||
'''
|
||||
|
||||
documented_code = '''
|
||||
"""
|
||||
This is an example module containing a function and a class definition.
|
||||
"""
|
||||
|
||||
|
||||
def add_numbers(a: int, b: int):
|
||||
"""This function is used to add two numbers and return the result.
|
||||
|
||||
Parameters:
|
||||
a: The first integer.
|
||||
b: The second integer.
|
||||
|
||||
Returns:
|
||||
int: The sum of the two numbers.
|
||||
"""
|
||||
return a + b
|
||||
|
||||
class Person:
|
||||
"""This class represents a person's information, including name and age.
|
||||
|
||||
Attributes:
|
||||
name: The person's name.
|
||||
age: The person's age.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, age: int):
|
||||
"""Creates a new instance of the Person class.
|
||||
|
||||
Parameters:
|
||||
name: The person's name.
|
||||
age: The person's age.
|
||||
"""
|
||||
...
|
||||
|
||||
def greet(self):
|
||||
"""
|
||||
Returns a greeting message including the name and age.
|
||||
|
||||
Returns:
|
||||
str: The greeting message.
|
||||
"""
|
||||
...
|
||||
'''
|
||||
|
||||
|
||||
merged_code = '''
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
This is an example module containing a function and a class definition.
|
||||
"""
|
||||
|
||||
from typing import overload
|
||||
|
||||
@overload
|
||||
def add_numbers(a: int, b: int):
|
||||
...
|
||||
|
||||
@overload
|
||||
def add_numbers(a: float, b: float):
|
||||
...
|
||||
|
||||
def add_numbers(a: int, b: int):
|
||||
"""This function is used to add two numbers and return the result.
|
||||
|
||||
Parameters:
|
||||
a: The first integer.
|
||||
b: The second integer.
|
||||
|
||||
Returns:
|
||||
int: The sum of the two numbers.
|
||||
"""
|
||||
return a + b
|
||||
|
||||
|
||||
class Person:
|
||||
"""This class represents a person's information, including name and age.
|
||||
|
||||
Attributes:
|
||||
name: The person's name.
|
||||
age: The person's age.
|
||||
"""
|
||||
def __init__(self, name: str, age: int):
|
||||
"""Creates a new instance of the Person class.
|
||||
|
||||
Parameters:
|
||||
name: The person's name.
|
||||
age: The person's age.
|
||||
"""
|
||||
self.name = name
|
||||
self.age = age
|
||||
|
||||
def greet(self):
|
||||
"""
|
||||
Returns a greeting message including the name and age.
|
||||
|
||||
Returns:
|
||||
str: The greeting message.
|
||||
"""
|
||||
return f"Hello, my name is {self.name} and I am {self.age} years old."
|
||||
'''
|
||||
|
||||
|
||||
def test_merge_docstring():
|
||||
data = pycst.merge_docstring(code, documented_code)
|
||||
print(data)
|
||||
assert data == merged_code
|
||||
Loading…
Add table
Add a link
Reference in a new issue