feat: +google style docstring

This commit is contained in:
莘权 马 2024-02-01 17:43:18 +08:00
parent 6b527e3020
commit 027f1e8658
6 changed files with 757 additions and 55 deletions

View file

@ -4,11 +4,11 @@
@Time : 2023/12/19
@Author : mashenquan
@File : rebuild_class_view.py
@Desc : Rebuild class view info
@Desc : Reconstructs class diagram from a source code project.
"""
from pathlib import Path
from typing import Optional
from typing import Optional, Set, Tuple
import aiofiles
@ -30,9 +30,26 @@ from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class RebuildClassView(Action):
"""
Reconstructs a graph repository about class diagram from a source code project.
Attributes:
graph_db (Optional[GraphRepository]): The optional graph repository.
"""
graph_db: Optional[GraphRepository] = None
async def run(self, with_messages=None, format=config.prompt_schema):
"""
Implementation of `Action`'s `run` method.
Args:
with_messages (Optional[Type]): An optional argument specifying messages to react to.
format (str): The format for the prompt schema.
Returns:
None
"""
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
repo_parser = RepoParser(base_directory=Path(self.i_context))
@ -52,6 +69,13 @@ class RebuildClassView(Action):
await self.graph_db.save()
async def _create_mermaid_class_views(self):
"""Creates a Mermaid class diagram using data from the `graph_db` graph repository.
This method utilizes information stored in the graph repository to generate a Mermaid class diagram.
Returns:
None
"""
path = Path(self.context.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
path.mkdir(parents=True, exist_ok=True)
pathname = path / self.context.git_repo.workdir.name
@ -77,6 +101,14 @@ class RebuildClassView(Action):
logger.info(f"classes: {len(class_distinct)}, relationship: {len(relationship_distinct)}")
async def _create_mermaid_class(self, ns_class_name) -> str:
"""Generates a Mermaid class diagram for a specific class using data from the `graph_db` graph repository.
Args:
ns_class_name (str): The namespace-prefixed name of the class for which the Mermaid class diagram is to be created.
Returns:
str: A Mermaid code block object in markdown representing the class diagram.
"""
fields = split_namespace(ns_class_name)
if len(fields) > 2:
# Ignore sub-class
@ -110,11 +142,19 @@ class RebuildClassView(Action):
logger.debug(content)
return content
async def _create_mermaid_relationship(self, ns_class_name):
async def _create_mermaid_relationship(self, ns_class_name: str) -> Tuple[str, Set]:
"""Generates a Mermaid class relationship diagram for a specific class using data from the `graph_db` graph repository.
Args:
ns_class_name (str): The namespace-prefixed class name for which the Mermaid relationship diagram is to be created.
Returns:
Tuple[str, Set]: A tuple containing the relationship diagram as a string and a set of deduplication.
"""
s_fields = split_namespace(ns_class_name)
if len(s_fields) > 2:
# Ignore sub-class
return
return None, None
predicates = {GraphKeyword.IS + v + GraphKeyword.OF: v for v in [GENERALIZATION, COMPOSITION, AGGREGATION]}
mappings = {
@ -140,6 +180,15 @@ class RebuildClassView(Action):
@staticmethod
def _diff_path(path_root: Path, package_root: Path) -> (str, str):
"""Returns the difference between the root path and the path information represented in the package name.
Args:
path_root (Path): The root path.
package_root (Path): The package root path.
Returns:
Tuple[str, str]: A tuple containing the representation of the difference ("+", "-", "=") and the path detail of the differing part.
"""
if len(str(path_root)) > len(str(package_root)):
return "+", str(path_root.relative_to(package_root))
if len(str(path_root)) < len(str(package_root)):
@ -147,7 +196,17 @@ class RebuildClassView(Action):
return "=", "."
@staticmethod
def _align_root(path: str, direction: str, diff_path: str):
def _align_root(path: str, direction: str, diff_path: str) -> str:
"""Aligns the path to the same root represented by `diff_path`.
Args:
path (str): The path to be aligned.
direction (str): The direction of alignment ('+', '-', '=').
diff_path (str): The path representing the difference.
Returns:
str: The aligned path.
"""
if direction == "=":
return path
if direction == "+":

View file

@ -4,7 +4,7 @@
@Time : 2024/1/4
@Author : mashenquan
@File : rebuild_sequence_view.py
@Desc : Rebuild sequence view info
@Desc : Reconstruct sequence view information through reverse engineering.
"""
from __future__ import annotations
@ -37,7 +37,19 @@ from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import SPO, GraphKeyword, GraphRepository
class SQVUseCase(BaseModel):
class ReverseUseCase(BaseModel):
"""
Represents a reverse engineered use case.
Attributes:
description (str): A description of the reverse use case.
inputs (List[str]): List of inputs for the reverse use case.
outputs (List[str]): List of outputs for the reverse use case.
actors (List[str]): List of actors involved in the reverse use case.
steps (List[str]): List of steps for the reverse use case.
reason (str): The reason behind the reverse use case.
"""
description: str
inputs: List[str]
outputs: List[str]
@ -46,16 +58,42 @@ class SQVUseCase(BaseModel):
reason: str
class SQVUseCaseDetails(BaseModel):
class ReverseUseCaseDetails(BaseModel):
"""
Represents details of a reverse engineered use case.
Attributes:
description (str): A description of the reverse use case details.
use_cases (List[ReverseUseCase]): List of reverse use cases.
relationship (List[str]): List of relationships associated with the reverse use case details.
"""
description: str
use_cases: List[SQVUseCase]
use_cases: List[ReverseUseCase]
relationship: List[str]
class RebuildSequenceView(Action):
"""
Represents an action to reconstruct sequence view through reverse engineering.
Attributes:
graph_db (Optional[GraphRepository]): An optional instance of GraphRepository for graph database operations.
"""
graph_db: Optional[GraphRepository] = None
async def run(self, with_messages=None, format=config.prompt_schema):
"""
Implementation of `Action`'s `run` method.
Args:
with_messages (Optional[Type]): An optional argument specifying messages to react to.
format (str): The format for the prompt schema.
Returns:
None
"""
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
entries = await self._search_main_entry()
@ -65,12 +103,22 @@ class RebuildSequenceView(Action):
pass
await self.graph_db.save()
# @retry(
# wait=wait_random_exponential(min=1, max=20),
# stop=stop_after_attempt(6),
# after=general_after_log(logger),
# )
async def _rebuild_main_sequence_view(self, entry):
@retry(
wait=wait_random_exponential(min=1, max=20),
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _rebuild_main_sequence_view(self, entry: SPO):
"""
Reconstruct the sequence diagram for the __main__ entry of the source code through reverse engineering.
Args:
entry (SPO): The SPO (Subject, Predicate, Object) object in the graph database that is related to the
subject `__name__:__main__`.
Returns:
None
"""
filename = entry.subject.split(":", 1)[0]
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
classes = []
@ -139,7 +187,16 @@ class RebuildSequenceView(Action):
)
await self.graph_db.save()
async def _merge_sequence_view(self, entry) -> bool:
async def _merge_sequence_view(self, entry: SPO) -> bool:
"""
Augments additional information to the provided SPO (Subject, Predicate, Object) entry in the sequence diagram.
Args:
entry (SPO): The SPO object representing the relationship in the graph database.
Returns:
bool: True if additional information has been augmented, otherwise False.
"""
new_participant = await self._search_new_participant(entry)
if not new_participant:
return False
@ -148,6 +205,12 @@ class RebuildSequenceView(Action):
return True
async def _search_main_entry(self) -> List:
"""
Asynchronously searches for the SPO object that is related to `__name__:__main__`.
Returns:
List: A list containing information about the main entry in the sequence diagram.
"""
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
tag = "__name__:__main__"
entries = []
@ -161,7 +224,16 @@ class RebuildSequenceView(Action):
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _rebuild_use_case(self, ns_class_name):
async def _rebuild_use_case(self, ns_class_name: str):
"""
Asynchronously reconstructs the use case for the provided namespace-prefixed class name.
Args:
ns_class_name (str): The namespace-prefixed class name for which the use case is to be reconstructed.
Returns:
None
"""
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE)
if rows:
return
@ -195,20 +267,25 @@ class RebuildSequenceView(Action):
system_msgs=[
"You are a python code to UML 2.0 Use Case translator.",
'The generated UML 2.0 Use Case must include the roles or entities listed in "Participants".',
'The functional descriptions of Actors and Use Cases in the generated UML 2.0 Use Case must not conflict with the information in "Mermaid Class Views".',
# 'Only steps that involve input, output, and interactive operations with the external system at the same time can be considered as independent use cases.',
# "Only steps that involve input, output, and interactive operations with the external system at the same time can be considered as independent use cases, steps that do not meet any one condition should be incorporated into other use cases.",
'The section under `if __name__ == "__main__":` of "Source Code" contains information about external system interactions with the internal system.',
"The functional descriptions of Actors and Use Cases in the generated UML 2.0 Use Case must not "
'conflict with the information in "Mermaid Class Views".',
'The section under `if __name__ == "__main__":` of "Source Code" contains information about external '
"system interactions with the internal system.",
"Return a markdown JSON object with:\n"
'- a "description" key to explain what the whole source code want to do;\n'
'- a "use_cases" key list all use cases, each use case in the list should including a `description` key describes about what the use case to do, a `inputs` key lists the input names of the use case from external sources, a `outputs` key lists the output names of the use case to external sources, a `actors` key lists the participant actors of the use case, a `steps` key lists the steps about how the use case works step by step, a `reason` key explaining under what circumstances would the external system execute this use case.\n'
'- a "use_cases" key list all use cases, each use case in the list should including a `description` '
"key describes about what the use case to do, a `inputs` key lists the input names of the use case "
"from external sources, a `outputs` key lists the output names of the use case to external sources, "
"a `actors` key lists the participant actors of the use case, a `steps` key lists the steps about how "
"the use case works step by step, a `reason` key explaining under what circumstances would the "
"external system execute this use case.\n"
'- a "relationship" key lists all the descriptions of relationship among these use cases.\n',
],
)
code_blocks = parse_json_code_block(rsp)
for block in code_blocks:
detail = SQVUseCaseDetails.model_validate_json(block)
detail = ReverseUseCaseDetails.model_validate_json(block)
await self.graph_db.insert(
subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE, object_=detail.model_dump_json()
)
@ -219,7 +296,16 @@ class RebuildSequenceView(Action):
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _rebuild_sequence_view(self, ns_class_name):
async def _rebuild_sequence_view(self, ns_class_name: str):
"""
Asynchronously reconstructs the sequence diagram for the provided namespace-prefixed class name.
Args:
ns_class_name (str): The namespace-prefixed class name for which the sequence diagram is to be reconstructed.
Returns:
None
"""
await self._rebuild_use_case(ns_class_name)
prompts_blocks = []
@ -262,7 +348,17 @@ class RebuildSequenceView(Action):
)
await self.graph_db.save()
async def _get_participants(self, ns_class_name) -> List[str]:
async def _get_participants(self, ns_class_name: str) -> List[str]:
"""
Asynchronously returns the participants list of the sequence diagram for the provided namespace-prefixed SPO
object.
Args:
ns_class_name (str): The namespace-prefixed class name for which to retrieve the participants list.
Returns:
List[str]: A list of participants in the sequence diagram.
"""
participants = set()
detail = await self._get_class_detail(ns_class_name)
if not detail:
@ -271,11 +367,20 @@ class RebuildSequenceView(Action):
participants.update(set(detail.aggregations))
return list(participants)
async def _get_class_use_cases(self, ns_class_name) -> str:
async def _get_class_use_cases(self, ns_class_name: str) -> str:
"""
Asynchronously assembles the context about the use case information of the namespace-prefixed SPO object.
Args:
ns_class_name (str): The namespace-prefixed class name for which to retrieve use case information.
Returns:
str: A string containing the assembled context about the use case information.
"""
block = ""
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE)
for i, r in enumerate(rows):
detail = SQVUseCaseDetails.model_validate_json(r.object_)
detail = ReverseUseCaseDetails.model_validate_json(r.object_)
block += f"\n### {i + 1}. {detail.description}"
for j, use_case in enumerate(detail.use_cases):
block += f"\n#### {i + 1}.{j + 1}. {use_case.description}\n"
@ -286,21 +391,50 @@ class RebuildSequenceView(Action):
block += "\n#### Use Case Relationship\n" + "\n".join([f"- {s}" for s in detail.relationship])
return block + "\n"
async def _get_class_detail(self, ns_class_name) -> DotClassInfo | None:
async def _get_class_detail(self, ns_class_name: str) -> DotClassInfo | None:
"""
Asynchronously retrieves the dot format class details of the namespace-prefixed SPO object.
Args:
ns_class_name (str): The namespace-prefixed class name for which to retrieve class details.
Returns:
Union[DotClassInfo, None]: A DotClassInfo object representing the dot format class details,
or None if the details are not available.
"""
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_DETAIL)
if not rows:
return None
dot_class_info = DotClassInfo.model_validate_json(rows[0].object_)
return dot_class_info
async def _get_uml_class_view(self, ns_class_name) -> UMLClassView | None:
async def _get_uml_class_view(self, ns_class_name: str) -> UMLClassView | None:
"""
Asynchronously retrieves the UML 2.0 format class details of the namespace-prefixed SPO object.
Args:
ns_class_name (str): The namespace-prefixed class name for which to retrieve UML class details.
Returns:
Union[UMLClassView, None]: A UMLClassView object representing the UML 2.0 format class details,
or None if the details are not available.
"""
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
if not rows:
return None
class_view = UMLClassView.model_validate_json(rows[0].object_)
return class_view
async def _get_source_code(self, ns_class_name) -> str:
async def _get_source_code(self, ns_class_name: str) -> str:
"""
Asynchronously retrieves the source code of the namespace-prefixed SPO object.
Args:
ns_class_name (str): The namespace-prefixed class name for which to retrieve the source code.
Returns:
str: A string containing the source code of the specified namespace-prefixed class.
"""
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_PAGE_INFO)
filename = split_namespace(ns_class_name=ns_class_name)[0]
if not rows:
@ -315,6 +449,16 @@ class RebuildSequenceView(Action):
@staticmethod
def _get_full_filename(root: str | Path, pathname: str | Path) -> Path | None:
"""
Convert package name to the full path of the module.
Args:
root (Union[str, Path]): The root path or string representing the package.
pathname (Union[str, Path]): The pathname or string representing the module.
Returns:
Union[Path, None]: The full path of the module, or None if the path cannot be determined.
"""
files = list_files(root=root)
postfix = "/" + str(pathname)
for i in files:
@ -324,11 +468,30 @@ class RebuildSequenceView(Action):
@staticmethod
def parse_participant(mermaid_sequence_diagram: str) -> List[str]:
"""
Parses the provided Mermaid sequence diagram and returns the list of participants.
Args:
mermaid_sequence_diagram (str): The Mermaid sequence diagram string to be parsed.
Returns:
List[str]: A list of participants extracted from the sequence diagram.
"""
pattern = r"participant ([a-zA-Z\.0-9_]+)"
matches = re.findall(pattern, mermaid_sequence_diagram)
matches = [re.sub(r"[\\/'\"]+", "", i) for i in matches]
return matches
async def _search_new_participant(self, entry: SPO) -> str | None:
"""
Asynchronously retrieves a participant whose sequence diagram has not been augmented.
Args:
entry (SPO): The SPO object representing the relationship in the graph database.
Returns:
Union[str, None]: A participant whose sequence diagram has not been augmented, or None if not found.
"""
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
if not rows:
return None
@ -351,6 +514,16 @@ class RebuildSequenceView(Action):
after=general_after_log(logger),
)
async def _merge_participant(self, entry: SPO, class_name: str):
"""
Augments the sequence diagram of `class_name` to the sequence diagram of `entry`.
Args:
entry (SPO): The SPO object representing the base sequence diagram.
class_name (str): The class name whose sequence diagram is to be augmented.
Returns:
None
"""
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
participants = []
for r in rows:

View file

@ -1,6 +1,10 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Build a symbols repository from source code.
This script is designed to create a symbols repository from the provided source code.
@Time : 2023/11/17 17:58
@Author : alexanderwu
@File : repo_parser.py
@ -24,6 +28,17 @@ from metagpt.utils.exceptions import handle_exception
class RepoFileInfo(BaseModel):
"""
Repository data element that represents information about a file.
Attributes:
file (str): The name or path of the file.
classes (List): A list of class names present in the file.
functions (List): A list of function names present in the file.
globals (List): A list of global variable names present in the file.
page_info (List): A list of page-related information associated with the file.
"""
file: str
classes: List = Field(default_factory=list)
functions: List = Field(default_factory=list)
@ -32,6 +47,17 @@ class RepoFileInfo(BaseModel):
class CodeBlockInfo(BaseModel):
"""
Repository data element representing information about a code block.
Attributes:
lineno (int): The starting line number of the code block.
end_lineno (int): The ending line number of the code block.
type_name (str): The type or category of the code block.
tokens (List): A list of tokens present in the code block.
properties (Dict): A dictionary containing additional properties associated with the code block.
"""
lineno: int
end_lineno: int
type_name: str
@ -40,6 +66,17 @@ class CodeBlockInfo(BaseModel):
class DotClassAttribute(BaseModel):
"""
Repository data element representing a class attribute in dot format.
Attributes:
name (str): The name of the class attribute.
type_ (str): The type of the class attribute.
default_ (str): The default value of the class attribute.
description (str): A description of the class attribute.
compositions (List[str]): A list of compositions associated with the class attribute.
"""
name: str = ""
type_: str = ""
default_: str = ""
@ -48,6 +85,15 @@ class DotClassAttribute(BaseModel):
@classmethod
def parse(cls, v: str) -> "DotClassAttribute":
"""
Parses dot format text and returns a DotClassAttribute object.
Args:
v (str): Dot format text to be parsed.
Returns:
DotClassAttribute: An instance of the DotClassAttribute class representing the parsed data.
"""
val = ""
meet_colon = False
meet_equals = False
@ -89,10 +135,28 @@ class DotClassAttribute(BaseModel):
@staticmethod
def remove_white_spaces(v: str):
"""
Removes white spaces from the provided string, excluding spaces within quotes.
Args:
v (str): The input string containing white spaces.
Returns:
str: The input string with white spaces removed, excluding spaces within quotes.
"""
return re.sub(r"(?<!['\"])\s|(?<=['\"])\s", "", v)
@staticmethod
def parse_compositions(types_part) -> List[str]:
"""
Parses the type definition code block of source code and returns a list of compositions.
Args:
types_part: The type definition code block to be parsed.
Returns:
List[str]: A list of compositions extracted from the type definition code block.
"""
if not types_part:
return []
modified_string = re.sub(r"[\[\],\(\)]", "|", types_part)
@ -128,6 +192,15 @@ class DotClassAttribute(BaseModel):
@staticmethod
def _split_literal(v):
"""
Parses the literal definition code block and returns three parts: pre-part, literal-part, and post-part.
Args:
v: The literal definition code block to be parsed.
Returns:
Tuple[str, str, str]: A tuple containing the pre-part, literal-part, and post-part of the code block.
"""
tag = "Literal["
bix = v.find(tag)
eix = len(v) - 1
@ -153,11 +226,32 @@ class DotClassAttribute(BaseModel):
@field_validator("compositions", mode="after")
@classmethod
def sort(cls, lst: List) -> List:
"""
Auto-sorts a list attribute after making changes.
Args:
lst (List): The list attribute to be sorted.
Returns:
List: The sorted list.
"""
lst.sort()
return lst
class DotClassInfo(BaseModel):
"""
Repository data element representing information about a class in dot format.
Attributes:
name (str): The name of the class.
package (Optional[str]): The package to which the class belongs (optional).
attributes (Dict[str, DotClassAttribute]): A dictionary of attributes associated with the class.
methods (Dict[str, DotClassMethod]): A dictionary of methods associated with the class.
compositions (List[str]): A list of compositions associated with the class.
aggregations (List[str]): A list of aggregations associated with the class.
"""
name: str
package: Optional[str] = None
attributes: Dict[str, DotClassAttribute] = Field(default_factory=dict)
@ -168,11 +262,30 @@ class DotClassInfo(BaseModel):
@field_validator("compositions", "aggregations", mode="after")
@classmethod
def sort(cls, lst: List) -> List:
"""
Auto-sorts a list attribute after making changes.
Args:
lst (List): The list attribute to be sorted.
Returns:
List: The sorted list.
"""
lst.sort()
return lst
class DotClassRelationship(BaseModel):
"""
Repository data element representing a relationship between two classes in dot format.
Attributes:
src (str): The source class of the relationship.
dest (str): The destination class of the relationship.
relationship (str): The type or nature of the relationship.
label (Optional[str]): An optional label associated with the relationship.
"""
src: str = ""
dest: str = ""
relationship: str = ""
@ -180,12 +293,31 @@ class DotClassRelationship(BaseModel):
class DotReturn(BaseModel):
"""
Repository data element representing a function or method return type in dot format.
Attributes:
type_ (str): The type of the return.
description (str): A description of the return type.
compositions (List[str]): A list of compositions associated with the return type.
"""
type_: str = ""
description: str
compositions: List[str] = Field(default_factory=list)
@classmethod
def parse(cls, v: str) -> "DotReturn" | None:
"""
Parses the return type part of dot format text and returns a DotReturn object.
Args:
v (str): The dot format text containing the return type part to be parsed.
Returns:
DotReturn | None: An instance of the DotReturn class representing the parsed return type,
or None if parsing fails.
"""
if not v:
return DotReturn(description=v)
type_ = DotClassAttribute.remove_white_spaces(v)
@ -195,6 +327,15 @@ class DotReturn(BaseModel):
@field_validator("compositions", mode="after")
@classmethod
def sort(cls, lst: List) -> List:
"""
Auto-sorts a list attribute after making changes.
Args:
lst (List): The list attribute to be sorted.
Returns:
List: The sorted list.
"""
lst.sort()
return lst
@ -208,6 +349,15 @@ class DotClassMethod(BaseModel):
@classmethod
def parse(cls, v: str) -> "DotClassMethod":
"""
Parses a dot format method text and returns a DotClassMethod object.
Args:
v (str): The dot format text containing method information to be parsed.
Returns:
DotClassMethod: An instance of the DotClassMethod class representing the parsed method.
"""
bix = v.find("(")
eix = v.rfind(")")
rix = v.rfind(":")
@ -229,6 +379,15 @@ class DotClassMethod(BaseModel):
@staticmethod
def _parse_name(v: str) -> str:
"""
Parses the dot format method name part and returns the method name.
Args:
v (str): The dot format text containing the method name part to be parsed.
Returns:
str: The parsed method name.
"""
tags = [">", "</"]
if tags[0] in v:
bix = v.find(tags[0]) + len(tags[0])
@ -238,6 +397,15 @@ class DotClassMethod(BaseModel):
@staticmethod
def _parse_args(v: str) -> List[DotClassAttribute]:
"""
Parses the dot format method arguments part and returns the parsed arguments.
Args:
v (str): The dot format text containing the arguments part to be parsed.
Returns:
str: The parsed method arguments.
"""
if not v:
return []
parts = []
@ -265,16 +433,40 @@ class DotClassMethod(BaseModel):
class RepoParser(BaseModel):
"""
Tool to build a symbols repository from a project directory.
Attributes:
base_directory (Path): The base directory of the project.
"""
base_directory: Path = Field(default=None)
@classmethod
@handle_exception(exception_type=Exception, default_return=[])
def _parse_file(cls, file_path: Path) -> list:
"""Parse a Python file in the repository."""
"""
Parses a Python file in the repository.
Args:
file_path (Path): The path to the Python file to be parsed.
Returns:
list: A list containing the parsed symbols from the file.
"""
return ast.parse(file_path.read_text()).body
def extract_class_and_function_info(self, tree, file_path) -> RepoFileInfo:
"""Extract class, function, and global variable information from the AST."""
"""
Extracts class, function, and global variable information from the Abstract Syntax Tree (AST).
Args:
tree: The Abstract Syntax Tree (AST) of the Python file.
file_path: The path to the Python file.
Returns:
RepoFileInfo: A RepoFileInfo object containing the extracted information.
"""
file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory)))
for node in tree:
info = RepoParser.node_to_str(node)
@ -292,6 +484,12 @@ class RepoParser(BaseModel):
return file_info
def generate_symbols(self) -> List[RepoFileInfo]:
"""
Builds a symbol repository from '.py' and '.js' files in the project directory.
Returns:
List[RepoFileInfo]: A list of RepoFileInfo objects containing the extracted information.
"""
files_classes = []
directory = self.base_directory
@ -306,19 +504,38 @@ class RepoParser(BaseModel):
return files_classes
def generate_json_structure(self, output_path):
"""Generate a JSON file documenting the repository structure."""
def generate_json_structure(self, output_path: Path):
"""
Generates a JSON file documenting the repository structure.
Args:
output_path (Path): The path to the JSON file to be generated.
"""
files_classes = [i.model_dump() for i in self.generate_symbols()]
output_path.write_text(json.dumps(files_classes, indent=4))
def generate_dataframe_structure(self, output_path):
"""Generate a DataFrame documenting the repository structure and save as CSV."""
def generate_dataframe_structure(self, output_path: Path):
"""
Generates a DataFrame documenting the repository structure and saves it as a CSV file.
Args:
output_path (Path): The path to the CSV file to be generated.
"""
files_classes = [i.model_dump() for i in self.generate_symbols()]
df = pd.DataFrame(files_classes)
df.to_csv(output_path, index=False)
def generate_structure(self, output_path=None, mode="json") -> Path:
"""Generate the structure of the repository as a specified format."""
def generate_structure(self, output_path: str | Path = None, mode="json") -> Path:
"""
Generates the structure of the repository in a specified format.
Args:
output_path (str | Path): The path to the output file or directory. Default is None.
mode (str): The output format mode. Options: "json" (default), "csv", etc.
Returns:
Path: The path to the generated output file or directory.
"""
output_file = self.base_directory / f"{self.base_directory.name}-structure.{mode}"
output_path = Path(output_path) if output_path else output_file
@ -330,6 +547,16 @@ class RepoParser(BaseModel):
@staticmethod
def node_to_str(node) -> CodeBlockInfo | None:
"""
Parses and converts an Abstract Syntax Tree (AST) node to a CodeBlockInfo object.
Args:
node: The AST node to be converted.
Returns:
CodeBlockInfo | None: A CodeBlockInfo object representing the parsed AST node,
or None if the conversion fails.
"""
if isinstance(node, ast.Try):
return None
if any_to_str(node) == any_to_str(ast.Expr):
@ -370,6 +597,15 @@ class RepoParser(BaseModel):
@staticmethod
def _parse_expr(node) -> List:
"""
Parses an expression Abstract Syntax Tree (AST) node.
Args:
node: The AST node representing an expression.
Returns:
List: A list containing the parsed information from the expression node.
"""
funcs = {
any_to_str(ast.Constant): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value)],
any_to_str(ast.Call): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value.func)],
@ -381,12 +617,30 @@ class RepoParser(BaseModel):
@staticmethod
def _parse_name(n):
"""
Gets the 'name' value of an Abstract Syntax Tree (AST) node.
Args:
n: The AST node.
Returns:
The 'name' value of the AST node.
"""
if n.asname:
return f"{n.name} as {n.asname}"
return n.name
@staticmethod
def _parse_if(n):
"""
Parses an 'if' statement Abstract Syntax Tree (AST) node.
Args:
n: The AST node representing an 'if' statement.
Returns:
None or Parsed information from the 'if' statement node.
"""
tokens = []
try:
if isinstance(n.test, ast.BoolOp):
@ -409,6 +663,15 @@ class RepoParser(BaseModel):
@staticmethod
def _parse_if_compare(n):
"""
Parses an 'if' condition Abstract Syntax Tree (AST) node.
Args:
n: The AST node representing an 'if' condition.
Returns:
None or Parsed information from the 'if' condition node.
"""
if hasattr(n, "left"):
return RepoParser._parse_variable(n.left)
else:
@ -416,6 +679,15 @@ class RepoParser(BaseModel):
@staticmethod
def _parse_variable(node):
"""
Parses a variable Abstract Syntax Tree (AST) node.
Args:
node: The AST node representing a variable.
Returns:
None or Parsed information from the variable node.
"""
try:
funcs = {
any_to_str(ast.Constant): lambda x: x.value,
@ -435,9 +707,27 @@ class RepoParser(BaseModel):
@staticmethod
def _parse_assign(node):
"""
Parses an assignment Abstract Syntax Tree (AST) node.
Args:
node: The AST node representing an assignment.
Returns:
None or Parsed information from the assignment node.
"""
return [RepoParser._parse_variable(t) for t in node.targets]
async def rebuild_class_views(self, path: str | Path = None):
"""
Executes `pylint` to reconstruct the dot format class view repository file.
Args:
path (str | Path): The path to the target directory or file. Default is None.
Returns:
None
"""
if not path:
path = self.base_directory
path = Path(path)
@ -458,7 +748,17 @@ class RepoParser(BaseModel):
packages_pathname.unlink(missing_ok=True)
return class_views, relationship_views, package_root
async def _parse_classes(self, class_view_pathname):
@staticmethod
async def _parse_classes(class_view_pathname: Path) -> List[DotClassInfo]:
"""
Parses a dot format class view repository file.
Args:
class_view_pathname (Path): The path to the dot format class view repository file.
Returns:
List[DotClassInfo]: A list of DotClassInfo objects representing the parsed classes.
"""
class_views = []
if not class_view_pathname.exists():
return class_views
@ -490,7 +790,17 @@ class RepoParser(BaseModel):
class_views.append(class_info)
return class_views
async def _parse_class_relationships(self, class_view_pathname) -> List[DotClassRelationship]:
@staticmethod
async def _parse_class_relationships(class_view_pathname: Path) -> List[DotClassRelationship]:
"""
Parses a dot format class view repository file.
Args:
class_view_pathname (Path): The path to the dot format class view repository file.
Returns:
List[DotClassRelationship]: A list of DotClassRelationship objects representing the parsed class relationships.
"""
relationship_views = []
if not class_view_pathname.exists():
return relationship_views
@ -504,7 +814,16 @@ class RepoParser(BaseModel):
return relationship_views
@staticmethod
def _split_class_line(line):
def _split_class_line(line: str) -> (str, str):
"""
Parses a dot format line about class info and returns the class name part and class members part.
Args:
line (str): The dot format line containing class information.
Returns:
Tuple[str, str]: A tuple containing the class name part and class members part.
"""
part_splitor = '" ['
if part_splitor not in line:
return None, None
@ -522,7 +841,17 @@ class RepoParser(BaseModel):
return class_name, info
@staticmethod
def _split_relationship_line(line):
def _split_relationship_line(line: str) -> str:
"""
Parses a dot format line about the relationship of two classes and returns 'Generalize', 'Composite',
or 'Aggregate'.
Args:
line (str): The dot format line containing relationship information.
Returns:
str: The type of relationship, either 'Generalize', 'Composite', or 'Aggregate'.
"""
splitters = [" -> ", " [", "];"]
idxs = []
for tag in splitters:
@ -547,7 +876,16 @@ class RepoParser(BaseModel):
return ret
@staticmethod
def _get_label(line):
def _get_label(line: str) -> str:
"""
Parses a dot format line and returns the label information.
Args:
line (str): The dot format line containing label information.
Returns:
str: The label information parsed from the line.
"""
tag = 'label="'
if tag not in line:
return ""
@ -557,6 +895,15 @@ class RepoParser(BaseModel):
@staticmethod
def _create_path_mapping(path: str | Path) -> Dict[str, str]:
"""
Creates a mapping table between source code files' paths and module names.
Args:
path (str | Path): The path to the source code files or directory.
Returns:
Dict[str, str]: A dictionary mapping source code file paths to their corresponding module names.
"""
mappings = {
str(path).replace("/", "."): str(path),
}
@ -582,6 +929,19 @@ class RepoParser(BaseModel):
def _repair_namespaces(
class_views: List[DotClassInfo], relationship_views: List[DotClassRelationship], path: str | Path
) -> (List[DotClassInfo], List[DotClassRelationship], str):
"""
Augments namespaces to the path-prefixed classes and relationships.
Args:
class_views (List[DotClassInfo]): List of DotClassInfo objects representing class views.
relationship_views (List[DotClassRelationship]): List of DotClassRelationship objects representing
relationships.
path (str | Path): The path to the source code files or directory.
Returns:
Tuple[List[DotClassInfo], List[DotClassRelationship], str]: A tuple containing the augmented class views,
relationships, and the root path of the package.
"""
if not class_views:
return [], [], ""
c = class_views[0]
@ -608,8 +968,19 @@ class RepoParser(BaseModel):
return class_views, relationship_views, root_path
@staticmethod
def _repair_ns(package, mappings):
def _repair_ns(package: str, mappings: Dict[str, str]) -> str:
"""
Replaces the package-prefix with the namespace-prefix.
Args:
package (str): The package to be repaired.
mappings (Dict[str, str]): A dictionary mapping source code file paths to their corresponding packages.
Returns:
str: The repaired namespace.
"""
file_ns = package
ix = 0
while file_ns != "":
if file_ns not in mappings:
ix = file_ns.rfind(".")
@ -621,7 +992,17 @@ class RepoParser(BaseModel):
return ns
@staticmethod
def _find_root(full_key, package) -> str:
def _find_root(full_key: str, package: str) -> str:
"""
Returns the package root path based on the key, which is the full path, and the package information.
Args:
full_key (str): The full key representing the full path.
package (str): The package information.
Returns:
str: The package root path.
"""
left = full_key
while left != "":
if left in package:
@ -634,5 +1015,14 @@ class RepoParser(BaseModel):
return "." + full_key[0:ix]
def is_func(node):
def is_func(node) -> bool:
"""
Returns True if the given node represents a function.
Args:
node: The Abstract Syntax Tree (AST) node.
Returns:
bool: True if the node represents a function, False otherwise.
"""
return isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))

View file

@ -4,7 +4,7 @@
@Time : 2023/12/19
@Author : mashenquan
@File : visualize_graph.py
@Desc : Visualize the graph.
@Desc : Visualization tool to visualize the class diagrams or sequence diagrams of the graph repository.
"""
from __future__ import annotations
@ -23,13 +23,31 @@ from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class _VisualClassView(BaseModel):
"""Protected class used by VisualGraphRepo internally.
Attributes:
package (str): The package associated with the class.
uml (Optional[UMLClassView]): Optional UMLClassView associated with the class.
generalizations (List[str]): List of generalizations for the class.
compositions (List[str]): List of compositions for the class.
aggregations (List[str]): List of aggregations for the class.
"""
package: str
uml: Optional[UMLClassView] = None
generalizations: List[str] = Field(default_factory=list)
compositions: List[str] = Field(default_factory=list)
aggregations: List[str] = Field(default_factory=list)
def get_mermaid(self, align: int = 1):
def get_mermaid(self, align: int = 1) -> str:
"""Creates a Markdown Mermaid class diagram text.
Args:
align (int): Indent count used for alignment.
Returns:
str: The Markdown text representing the Mermaid class diagram.
"""
if not self.uml:
return ""
prefix = "\t" * align
@ -45,23 +63,47 @@ class _VisualClassView(BaseModel):
@property
def name(self) -> str:
"""Returns the class name without the namespace prefix."""
return split_namespace(self.package)[-1]
class VisualGraphRepo(ABC):
"""Abstract base class for VisualGraphRepo."""
graph_db: GraphRepository
def __init__(self, graph_db):
"""Initializes a VisualGraphRepo instance with a specified graph database.
Args:
graph_db (GraphRepository): The graph repository used by the VisualGraphRepo.
"""
self.graph_db = graph_db
class VisualDiGraphRepo(VisualGraphRepo):
"""Implementation of VisualGraphRepo for networkx graph repository.
This class extends VisualGraphRepo to provide specific functionality for a graph repository using networkx.
"""
@classmethod
async def load_from(cls, filename: str | Path):
"""Load a VisualDiGraphRepo instance from a file.
Args:
filename (Union[str, Path]): The path to the file containing the graph data.
Returns:
VisualDiGraphRepo: An instance of VisualDiGraphRepo loaded from the specified file.
"""
graph_db = await DiGraphRepository.load_from(str(filename))
return cls(graph_db=graph_db)
async def get_mermaid_class_view(self) -> str:
"""
Returns a Markdown Mermaid class diagram code block object.
"""
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
mermaid_txt = "classDiagram\n"
for r in rows:
@ -70,6 +112,14 @@ class VisualDiGraphRepo(VisualGraphRepo):
return mermaid_txt
async def _get_class_view(self, ns_class_name: str) -> _VisualClassView:
"""Returns the Markdown Mermaid class diagram code block object for the specified class.
Args:
ns_class_name (str): The namespace-prefixed class name.
Returns:
_VisualClassView: An instance of _VisualClassView representing the class diagram.
"""
rows = await self.graph_db.select(subject=ns_class_name)
class_view = _VisualClassView(package=ns_class_name)
for r in rows:
@ -93,6 +143,12 @@ class VisualDiGraphRepo(VisualGraphRepo):
return class_view
async def get_mermaid_sequence_views(self) -> List[(str, str)]:
"""Returns all Markdown sequence diagrams with their corresponding graph repository keys.
Returns:
List[Tuple[str, str]]: A list of tuples containing Markdown sequence diagrams and their graph repository
keys.
"""
sequence_views = []
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
for r in rows:
@ -101,6 +157,14 @@ class VisualDiGraphRepo(VisualGraphRepo):
@staticmethod
def _refine_name(name) -> str:
"""Removes impurity content from the given name.
Args:
name: The name to be refined.
Returns:
str: The refined name.
"""
name = re.sub(r'^[\'"\\\(\)]+|[\'"\\\(\)]+$', "", name)
if name in ["int", "float", "bool", "str", "list", "tuple", "set", "dict", "None"]:
return ""
@ -110,6 +174,12 @@ class VisualDiGraphRepo(VisualGraphRepo):
return name
async def get_mermaid_sequence_view_versions(self) -> List[(str, str)]:
"""Returns all versioned Markdown sequence diagrams with their corresponding graph repository keys.
Returns:
List[Tuple[str, str]]: A list of tuples containing versioned Markdown sequence diagrams and their graph
repository keys.
"""
sequence_views = []
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER)
for r in rows:

View file

@ -4,6 +4,7 @@
@Time : 2024/1/4
@Author : mashenquan
@File : test_rebuild_sequence_view.py
@Desc : Unit tests for reconstructing the sequence diagram from a source code project.
"""
from pathlib import Path
@ -25,16 +26,16 @@ async def test_rebuild(context, mocker):
await context.repo.docs.graph_repo.save(filename=str(graph_db_filename), content=data)
context.git_repo.add_change({f"{GRAPH_REPO_FILE_REPO}/{graph_db_filename}": ChangeType.UNTRACTED})
context.git_repo.commit("commit1")
mock_spo = SPO(
subject="metagpt/startup.py:__name__:__main__",
predicate="has_page_info",
object_='{"lineno":78,"end_lineno":79,"type_name":"ast.If","tokens":["__name__","__main__"],"properties":{}}',
)
# mock_spo = SPO(
# subject="metagpt/tools/search_engine_serpapi.py:__name__:__main__",
# subject="metagpt/startup.py:__name__:__main__",
# predicate="has_page_info",
# object_='{"lineno":113,"end_lineno":116,"type_name":"ast.If","tokens":["__name__","__main__"],"properties":{}}',
# object_='{"lineno":78,"end_lineno":79,"type_name":"ast.If","tokens":["__name__","__main__"],"properties":{}}',
# )
mock_spo = SPO(
subject="metagpt/tools/search_engine_serpapi.py:__name__:__main__",
predicate="has_page_info",
object_='{"lineno":113,"end_lineno":116,"type_name":"ast.If","tokens":["__name__","__main__"],"properties":{}}',
)
mocker.patch.object(RebuildSequenceView, "_search_main_entry", return_value=[mock_spo])
action = RebuildSequenceView(

View file

@ -1,3 +1,12 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4
@Author : mashenquan
@File : test_visual_graph_repo.py
@Desc : Unit tests for testing and demonstrating the usage of VisualDiGraphRepo.
"""
import re
from pathlib import Path