diff --git a/metagpt/actions/write_docstring.py b/metagpt/actions/write_docstring.py new file mode 100644 index 000000000..091bd0a82 --- /dev/null +++ b/metagpt/actions/write_docstring.py @@ -0,0 +1,169 @@ +import ast +import contextlib +from typing import Literal + +from metagpt.actions.action import Action +from metagpt.utils.common import OutputParser +from metagpt.utils.pycst import merge_docstring + +PYTHON_DOCSTRING_SYSTEM = '''### Requirements +1. Add docstrings to the given code following the {style} style. +2. Remove all private members whose names start with an underscore, such as `_test` and `__init__`. +3. Replace the function body with an Ellipsis object(...) to reduce output. +4. If the types are already annotated, there is no need to include them in the docstring. +5. Only output Python code and avoid including any other text. + +### Input Example +```python +def function_with_pep484_type_annotations(param1: int) -> bool: + return isinstanc(param1, int) + +class ExampleError(Exception): + def __init__(self, msg: str): + self.msg = msg +``` + +### Output Example +```python{example}``` +''' + +# https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html + +PYTHON_DOCSTRING_EXAMPLE_GOOGLE = ''' +def function_with_pep484_type_annotations(param1: int) -> bool: + """Example function with PEP 484 type annotations. + + Extended description of function. + + Args: + param1: The first parameter. + + Returns: + The return value. True for success, False otherwise. + """ + ... + +class ExampleError(Exception): + """Exceptions are documented in the same way as classes. + + The __init__ method was documented in the class level docstring. + + Args: + msg: Human readable string describing the exception. + + Attributes: + msg: Human readable string describing the exception. + """ + ... +''' + +PYTHON_DOCSTRING_EXAMPLE_NUMPY = ''' +def function_with_pep484_type_annotations(param1: int) -> bool: + """ + Example function with PEP 484 type annotations. + + Extended description of function. + + Parameters + ---------- + param1 + The first parameter. + + Returns + ------- + bool + The return value. True for success, False otherwise. + """ + ... + +class ExampleError(Exception): + """ + Exceptions are documented in the same way as classes. + + The __init__ method was documented in the class level docstring. + + Parameters + ---------- + msg + Human readable string describing the exception. + + Attributes + ---------- + msg + Human readable string describing the exception. + """ + ... +''' + +PYTHON_DOCSTRING_EXAMPLE_SPHINX = ''' +def function_with_pep484_type_annotations(param1: int) -> bool: + """Example function with PEP 484 type annotations. + + Extended description of function. + + :param param1: The first parameter. + :type param1: int + + :return: The return value. True for success, False otherwise. + :rtype: bool + """ + ... + +class ExampleError(Exception): + """Exceptions are documented in the same way as classes. + + The __init__ method was documented in the class level docstring. + + :param msg: Human-readable string describing the exception. + :type msg: str + """ + ... +''' + +_python_docstring_style = { + "google": PYTHON_DOCSTRING_EXAMPLE_GOOGLE, + "numpy": PYTHON_DOCSTRING_EXAMPLE_NUMPY, + "sphinx": PYTHON_DOCSTRING_EXAMPLE_SPHINX, +} + + +class WriteDocstring(Action): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.desc = "Write docstring for code." + + async def run( + self, code: str, + system_text: str = PYTHON_DOCSTRING_SYSTEM, + style: Literal["google", "numpy", "sphinx"] = "google", + ) -> str: + system_text = system_text.format(style=style, example=_python_docstring_style[style]) + simplified_code = _simplify_python_code(code) + documented_code = await self._aask(simplified_code, [system_text]) + with contextlib.suppress(Exception): + documented_code = OutputParser.parse_code(documented_code) + return merge_docstring(code, documented_code) + + +def _simplify_python_code(code: str) -> None: + code_tree = ast.parse(code) + code_tree.body = [i for i in code_tree.body if not isinstance(i, ast.Expr)] + if isinstance(code_tree.body[-1], ast.If): + code_tree.body.pop() + return ast.unparse(code_tree) + + +if __name__ == "__main__": + import fire + + async def run(filename: str, overwrite: bool = False, style: Literal["google", "numpy", "sphinx"] = "google"): + with open(filename) as f: + code = f.read() + code = await WriteDocstring().run(code, style=style) + if overwrite: + with open(filename, "w") as f: + f.write(code) + return code + + fire.Fire(run) diff --git a/metagpt/utils/pycst.py b/metagpt/utils/pycst.py new file mode 100644 index 000000000..c2eb532ab --- /dev/null +++ b/metagpt/utils/pycst.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +from typing import Union + +import libcst as cst +from libcst._nodes.module import Module + +DocstringNode = Union[cst.Module, cst.ClassDef, cst.FunctionDef] + + +def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine: + """Extracts the docstring from the body of a node. + + Args: + body: The body of a node. + + Returns: + The docstring statement if it exists, None otherwise. + """ + if isinstance(body, cst.Module): + body = body.body + else: + body = body.body.body + + if not body: + return + + statement = body[0] + if not isinstance(statement, cst.SimpleStatementLine): + return + + expr = statement + while isinstance(expr, (cst.BaseSuite, cst.SimpleStatementLine)): + if len(expr.body) == 0: + return None + expr = expr.body[0] + + if not isinstance(expr, cst.Expr): + return None + + val = expr.value + if not isinstance(val, (cst.SimpleString, cst.ConcatenatedString)): + return None + + evaluated_value = val.evaluated_value + if isinstance(evaluated_value, bytes): + return None + + return statement + + +class DocstringCollector(cst.CSTVisitor): + """A visitor class for collecting docstrings from a CST. + + Attributes: + stack: A list to keep track of the current path in the CST. + docstrings: A dictionary mapping paths in the CST to their corresponding docstrings. + """ + def __init__(self): + self.stack: list[str] = [] + self.docstrings: dict[tuple[str, ...], cst.SimpleStatementLine] = {} + + def visit_Module(self, node: cst.Module) -> bool | None: + self.stack.append("") + + def leave_Module(self, node: cst.Module) -> None: + return self._leave(node) + + def visit_ClassDef(self, node: cst.ClassDef) -> bool | None: + self.stack.append(node.name.value) + + def leave_ClassDef(self, node: cst.ClassDef) -> None: + return self._leave(node) + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None: + self.stack.append(node.name.value) + + def leave_FunctionDef(self, node: cst.FunctionDef) -> None: + return self._leave(node) + + def _leave(self, node: DocstringNode) -> None: + key = tuple(self.stack) + self.stack.pop() + if hasattr(node, "decorators") and any(i.decorator.value == "overload" for i in node.decorators): + return + + statement = get_docstring_statement(node) + if statement: + self.docstrings[key] = statement + + +class DocstringTransformer(cst.CSTTransformer): + """A transformer class for replacing docstrings in a CST. + + Attributes: + stack: A list to keep track of the current path in the CST. + docstrings: A dictionary mapping paths in the CST to their corresponding docstrings. + """ + def __init__( + self, + docstrings: dict[tuple[str, ...], cst.SimpleStatementLine], + ): + self.stack: list[str] = [] + self.docstrings = docstrings + + def visit_Module(self, node: cst.Module) -> bool | None: + self.stack.append("") + + def leave_Module(self, original_node: Module, updated_node: Module) -> Module: + return self._leave(original_node, updated_node) + + def visit_ClassDef(self, node: cst.ClassDef) -> bool | None: + self.stack.append(node.name.value) + + def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.CSTNode: + return self._leave(original_node, updated_node) + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None: + self.stack.append(node.name.value) + + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode: + return self._leave(original_node, updated_node) + + def _leave(self, original_node: DocstringNode, updated_node: DocstringNode) -> DocstringNode: + key = tuple(self.stack) + self.stack.pop() + + if hasattr(updated_node, "decorators") and any((i.decorator.value == "overload") for i in updated_node.decorators): + return updated_node + + statement = self.docstrings.get(key) + if not statement: + return updated_node + + original_statement = get_docstring_statement(original_node) + + if isinstance(updated_node, cst.Module): + body = updated_node.body + if original_statement: + return updated_node.with_changes(body=(body[0], statement, *body[1:])) + else: + updated_node = updated_node.with_changes(body=(statement, cst.EmptyLine(), *body)) + return updated_node + + body = updated_node.body.body + if original_statement: + return updated_node.with_changes(body=updated_node.body.with_changes(body=(body[0], statement, *body[1:]))) + else: + return updated_node.with_changes(body=updated_node.body.with_changes(body=(statement, *body))) + + +def merge_docstring(code: str, documented_code: str) -> str: + """Merges the docstrings from the documented code into the original code. + + Args: + code: The original code. + documented_code: The documented code. + + Returns: + The original code with the docstrings from the documented code. + """ + code_tree = cst.parse_module(code) + documented_code_tree = cst.parse_module(documented_code) + + visitor = DocstringCollector() + documented_code_tree.visit(visitor) + transformer = DocstringTransformer(visitor.docstrings) + modified_tree = code_tree.visit(transformer) + return modified_tree.code diff --git a/requirements.txt b/requirements.txt index 32a436962..452e2d092 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,3 +35,4 @@ tqdm==4.64.0 anthropic==0.3.6 typing-inspect==0.8.0 typing_extensions==4.5.0 +libcst==1.0.1 diff --git a/tests/metagpt/actions/test_write_docstring.py b/tests/metagpt/actions/test_write_docstring.py new file mode 100644 index 000000000..82d96e1a6 --- /dev/null +++ b/tests/metagpt/actions/test_write_docstring.py @@ -0,0 +1,32 @@ +import pytest + +from metagpt.actions.write_docstring import WriteDocstring + +code = ''' +def add_numbers(a: int, b: int): + return a + b + + +class Person: + def __init__(self, name: str, age: int): + self.name = name + self.age = age + + def greet(self): + return f"Hello, my name is {self.name} and I am {self.age} years old." +''' + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("style", "part"), + [ + ("google", "Args:"), + ("numpy", "Parameters"), + ("sphinx", ":param name:"), + ], + ids=["google", "numpy", "sphinx"] +) +async def test_write_docstring(style: str, part: str): + ret = await WriteDocstring().run(code, style=style) + assert part in ret diff --git a/tests/metagpt/utils/test_pycst.py b/tests/metagpt/utils/test_pycst.py new file mode 100644 index 000000000..07352eac2 --- /dev/null +++ b/tests/metagpt/utils/test_pycst.py @@ -0,0 +1,136 @@ +from metagpt.utils import pycst + +code = ''' +#!/usr/bin/env python +# -*- coding: utf-8 -*- +from typing import overload + +@overload +def add_numbers(a: int, b: int): + ... + +@overload +def add_numbers(a: float, b: float): + ... + +def add_numbers(a: int, b: int): + return a + b + + +class Person: + def __init__(self, name: str, age: int): + self.name = name + self.age = age + + def greet(self): + return f"Hello, my name is {self.name} and I am {self.age} years old." +''' + +documented_code = ''' +""" +This is an example module containing a function and a class definition. +""" + + +def add_numbers(a: int, b: int): + """This function is used to add two numbers and return the result. + + Parameters: + a: The first integer. + b: The second integer. + + Returns: + int: The sum of the two numbers. + """ + return a + b + +class Person: + """This class represents a person's information, including name and age. + + Attributes: + name: The person's name. + age: The person's age. + """ + + def __init__(self, name: str, age: int): + """Creates a new instance of the Person class. + + Parameters: + name: The person's name. + age: The person's age. + """ + ... + + def greet(self): + """ + Returns a greeting message including the name and age. + + Returns: + str: The greeting message. + """ + ... +''' + + +merged_code = ''' +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +This is an example module containing a function and a class definition. +""" + +from typing import overload + +@overload +def add_numbers(a: int, b: int): + ... + +@overload +def add_numbers(a: float, b: float): + ... + +def add_numbers(a: int, b: int): + """This function is used to add two numbers and return the result. + + Parameters: + a: The first integer. + b: The second integer. + + Returns: + int: The sum of the two numbers. + """ + return a + b + + +class Person: + """This class represents a person's information, including name and age. + + Attributes: + name: The person's name. + age: The person's age. + """ + def __init__(self, name: str, age: int): + """Creates a new instance of the Person class. + + Parameters: + name: The person's name. + age: The person's age. + """ + self.name = name + self.age = age + + def greet(self): + """ + Returns a greeting message including the name and age. + + Returns: + str: The greeting message. + """ + return f"Hello, my name is {self.name} and I am {self.age} years old." +''' + + +def test_merge_docstring(): + data = pycst.merge_docstring(code, documented_code) + print(data) + assert data == merged_code