mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
add write docstring action
This commit is contained in:
parent
a5cb2fdd48
commit
bec5778dd0
5 changed files with 507 additions and 0 deletions
169
metagpt/utils/pycst.py
Normal file
169
metagpt/utils/pycst.py
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Union
|
||||
|
||||
import libcst as cst
|
||||
from libcst._nodes.module import Module
|
||||
|
||||
DocstringNode = Union[cst.Module, cst.ClassDef, cst.FunctionDef]
|
||||
|
||||
|
||||
def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine:
|
||||
"""Extracts the docstring from the body of a node.
|
||||
|
||||
Args:
|
||||
body: The body of a node.
|
||||
|
||||
Returns:
|
||||
The docstring statement if it exists, None otherwise.
|
||||
"""
|
||||
if isinstance(body, cst.Module):
|
||||
body = body.body
|
||||
else:
|
||||
body = body.body.body
|
||||
|
||||
if not body:
|
||||
return
|
||||
|
||||
statement = body[0]
|
||||
if not isinstance(statement, cst.SimpleStatementLine):
|
||||
return
|
||||
|
||||
expr = statement
|
||||
while isinstance(expr, (cst.BaseSuite, cst.SimpleStatementLine)):
|
||||
if len(expr.body) == 0:
|
||||
return None
|
||||
expr = expr.body[0]
|
||||
|
||||
if not isinstance(expr, cst.Expr):
|
||||
return None
|
||||
|
||||
val = expr.value
|
||||
if not isinstance(val, (cst.SimpleString, cst.ConcatenatedString)):
|
||||
return None
|
||||
|
||||
evaluated_value = val.evaluated_value
|
||||
if isinstance(evaluated_value, bytes):
|
||||
return None
|
||||
|
||||
return statement
|
||||
|
||||
|
||||
class DocstringCollector(cst.CSTVisitor):
|
||||
"""A visitor class for collecting docstrings from a CST.
|
||||
|
||||
Attributes:
|
||||
stack: A list to keep track of the current path in the CST.
|
||||
docstrings: A dictionary mapping paths in the CST to their corresponding docstrings.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.stack: list[str] = []
|
||||
self.docstrings: dict[tuple[str, ...], cst.SimpleStatementLine] = {}
|
||||
|
||||
def visit_Module(self, node: cst.Module) -> bool | None:
|
||||
self.stack.append("")
|
||||
|
||||
def leave_Module(self, node: cst.Module) -> None:
|
||||
return self._leave(node)
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
|
||||
self.stack.append(node.name.value)
|
||||
|
||||
def leave_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
return self._leave(node)
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
|
||||
self.stack.append(node.name.value)
|
||||
|
||||
def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
return self._leave(node)
|
||||
|
||||
def _leave(self, node: DocstringNode) -> None:
|
||||
key = tuple(self.stack)
|
||||
self.stack.pop()
|
||||
if hasattr(node, "decorators") and any(i.decorator.value == "overload" for i in node.decorators):
|
||||
return
|
||||
|
||||
statement = get_docstring_statement(node)
|
||||
if statement:
|
||||
self.docstrings[key] = statement
|
||||
|
||||
|
||||
class DocstringTransformer(cst.CSTTransformer):
|
||||
"""A transformer class for replacing docstrings in a CST.
|
||||
|
||||
Attributes:
|
||||
stack: A list to keep track of the current path in the CST.
|
||||
docstrings: A dictionary mapping paths in the CST to their corresponding docstrings.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
docstrings: dict[tuple[str, ...], cst.SimpleStatementLine],
|
||||
):
|
||||
self.stack: list[str] = []
|
||||
self.docstrings = docstrings
|
||||
|
||||
def visit_Module(self, node: cst.Module) -> bool | None:
|
||||
self.stack.append("")
|
||||
|
||||
def leave_Module(self, original_node: Module, updated_node: Module) -> Module:
|
||||
return self._leave(original_node, updated_node)
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
|
||||
self.stack.append(node.name.value)
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.CSTNode:
|
||||
return self._leave(original_node, updated_node)
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
|
||||
self.stack.append(node.name.value)
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode:
|
||||
return self._leave(original_node, updated_node)
|
||||
|
||||
def _leave(self, original_node: DocstringNode, updated_node: DocstringNode) -> DocstringNode:
|
||||
key = tuple(self.stack)
|
||||
self.stack.pop()
|
||||
|
||||
if hasattr(updated_node, "decorators") and any((i.decorator.value == "overload") for i in updated_node.decorators):
|
||||
return updated_node
|
||||
|
||||
statement = self.docstrings.get(key)
|
||||
if not statement:
|
||||
return updated_node
|
||||
|
||||
original_statement = get_docstring_statement(original_node)
|
||||
|
||||
if isinstance(updated_node, cst.Module):
|
||||
body = updated_node.body
|
||||
if original_statement:
|
||||
return updated_node.with_changes(body=(body[0], statement, *body[1:]))
|
||||
else:
|
||||
updated_node = updated_node.with_changes(body=(statement, cst.EmptyLine(), *body))
|
||||
return updated_node
|
||||
|
||||
body = updated_node.body.body
|
||||
if original_statement:
|
||||
return updated_node.with_changes(body=updated_node.body.with_changes(body=(body[0], statement, *body[1:])))
|
||||
else:
|
||||
return updated_node.with_changes(body=updated_node.body.with_changes(body=(statement, *body)))
|
||||
|
||||
|
||||
def merge_docstring(code: str, documented_code: str) -> str:
|
||||
"""Merges the docstrings from the documented code into the original code.
|
||||
|
||||
Args:
|
||||
code: The original code.
|
||||
documented_code: The documented code.
|
||||
|
||||
Returns:
|
||||
The original code with the docstrings from the documented code.
|
||||
"""
|
||||
code_tree = cst.parse_module(code)
|
||||
documented_code_tree = cst.parse_module(documented_code)
|
||||
|
||||
visitor = DocstringCollector()
|
||||
documented_code_tree.visit(visitor)
|
||||
transformer = DocstringTransformer(visitor.docstrings)
|
||||
modified_tree = code_tree.visit(transformer)
|
||||
return modified_tree.code
|
||||
Loading…
Add table
Add a link
Reference in a new issue