feat: +rfc197 example

This commit is contained in:
莘权 马 2024-03-06 19:30:06 +08:00
parent 5cae13fd0a
commit 3fee7a5368
6 changed files with 118 additions and 47 deletions

View file

@ -0,0 +1,72 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import asyncio
import shutil
from pathlib import Path
import typer
from metagpt.actions.rebuild_class_view import RebuildClassView
from metagpt.actions.rebuild_sequence_view import RebuildSequenceView
from metagpt.context import Context
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo
# Typer CLI application object; shell-completion options and local-variable dumps in
# pretty-printed exceptions are both switched off.
app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
@app.command("", help="Python project reverse engineering.")
def startup(
    project_root: str = typer.Argument(
        default="",
        help="Specify the root directory of the existing project for reverse engineering.",
    ),
    output_dir: str = typer.Option(default="", help="Specify the output directory path for reverse engineering."),
):
    """Validate the project directory, run the reverse engineering and clean up afterwards.

    Args:
        project_root: Root directory of the existing Python project.
        output_dir: Where the results go. When empty, ``reverse_engineering_output`` next to
            the project root (``<project_root>/../reverse_engineering_output``) is used.

    Raises:
        FileNotFoundError: If ``project_root`` does not exist, or if it has no ``*.py`` file
            directly inside it.
    """
    root = Path(project_root)
    if not root.exists():
        raise FileNotFoundError(f"{project_root} not exists")
    if not _is_python_package_root(root):
        raise FileNotFoundError(f'There are no "*.py" files under "{project_root}".')

    # pyreverse uses an __init__.py in the package root; create a temporary one when the
    # project does not ship its own, and remember that so it can be removed again.
    marker = root / "__init__.py"
    created_marker = not marker.exists()
    if created_marker:
        marker.touch()

    output_dir = output_dir or root / "../reverse_engineering_output"
    logger.info(f"output dir:{output_dir}")
    try:
        asyncio.run(reverse_engineering(root, Path(output_dir)))
    finally:
        # Leave the analysed project as it was found: drop the temporary __init__.py and
        # any "__dot__" directory present under the project root.
        if created_marker:
            marker.unlink(missing_ok=True)
        scratch_dir = root / "__dot__"
        if scratch_dir.exists():
            shutil.rmtree(scratch_dir, ignore_errors=True)
def _is_python_package_root(package_root: Path) -> bool:
for file_path in package_root.iterdir():
if file_path.is_file():
if file_path.suffix == ".py":
return True
return False
async def reverse_engineering(package_root: Path, output_dir: Path):
    """Rebuild the class view and then the sequence view of a Python package.

    Args:
        package_root: Root directory of the package to analyse; handed to
            ``RebuildClassView`` as its ``i_context`` (as a string).
        output_dir: Directory passed to ``GitRepository``; the resulting repository and the
            ``ProjectRepo`` built on it are attached to the shared context.
    """
    context = Context()
    context.git_repo = GitRepository(output_dir)
    context.repo = ProjectRepo(context.git_repo)

    # Both actions share one context; the sequence view is only created and run after the
    # class view has finished.
    class_view = RebuildClassView(
        name="ReverseEngineering",
        i_context=str(package_root),
        llm=LLM(),
        context=context,
    )
    await class_view.run()

    sequence_view = RebuildSequenceView(name="ReverseEngineering", llm=LLM(), context=context)
    await sequence_view.run()
# Run the Typer application when this file is executed as a script.
if __name__ == "__main__":
    app()

View file

@ -76,7 +76,7 @@ class RebuildClassView(Action):
path = self.context.git_repo.workdir / DATA_API_DESIGN_FILE_REPO
path.mkdir(parents=True, exist_ok=True)
pathname = path / self.context.git_repo.workdir.name
filename = str(pathname.with_suffix(".mmd"))
filename = str(pathname.with_suffix(".class_diagram.mmd"))
async with aiofiles.open(filename, mode="w", encoding="utf-8") as writer:
content = "classDiagram\n"
logger.debug(content)

View file

@ -12,7 +12,7 @@ from __future__ import annotations
import re
from datetime import datetime
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Set
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_random_exponential
@ -125,7 +125,7 @@ class RebuildSequenceView(Action):
if prefix in r.subject:
classes.append(r)
await self._rebuild_use_case(r.subject)
participants = set()
participants = await self._search_participants(split_namespace(entry.subject)[0])
class_details = []
class_views = []
for c in classes:
@ -171,7 +171,8 @@ class RebuildSequenceView(Action):
sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
for r in rows:
await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
if r.predicate == GraphKeyword.HAS_SEQUENCE_VIEW:
await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
)
@ -184,7 +185,7 @@ class RebuildSequenceView(Action):
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(c.subject)
)
await self.graph_db.save()
await self._save_sequence_view(subject=entry.subject, content=sequence_view)
async def _merge_sequence_view(self, entry: SPO) -> bool:
"""
@ -267,38 +268,6 @@ class RebuildSequenceView(Action):
prompt_blocks.append(block)
prompt = "\n---\n".join(prompt_blocks)
# class _UseCase(BaseModel):
# description: str = Field(default="...", description="Describes about what the use case to do")
# inputs: List[str] = Field(default=["input name 1", "input name 2"],
# description="Lists the input names of the use case from external sources")
# outputs: List[str] = Field(default=["output name 1", "output name 2"],
# description="Lists the output names of the use case to external sources")
# actors: List[str] = Field(default=["actor name 1", "actor name 2"],
# description="Lists the participant actors of the use case")
# steps: List[str] = Field(default=["Step 1", "Step 2"],
# description="Lists the steps about how the use case works step by step")
# reason: str = Field(default="Because ...",
# description="Explaining under what circumstances would the external system execute this use case.")
#
#
# class _UseCaseList(BaseModel):
# description: str = Field(default="...",
# description="A summary explains what the whole source code want to do")
# use_cases: List[_UseCase] = Field(default=[
# {
# "description": "Describes about what the use case to do",
# "inputs": ["input name 1", "input name 2"],
# "outputs": ["output name 1", "output name 2"],
# "actors": ["actor name 1", "actor name 2"],
# "steps": ["Step 1", "Step 2"],
# "reason": "Because ..."
# }
# ], description="List all use cases.")
# relationship: List[str] = Field(default=["use case 1 ..."],
# description="Lists all the descriptions of relationship among these use cases")
# rsp = await ActionNode.from_pydantic(_UseCaseList).fill(context=prompt, llm=self.llm)
rsp = await self.llm.aask(
msg=prompt,
system_msgs=[
@ -327,7 +296,6 @@ class RebuildSequenceView(Action):
await self.graph_db.insert(
subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE, object_=detail.model_dump_json()
)
await self.graph_db.save()
@retry(
wait=wait_random_exponential(min=1, max=20),
@ -347,7 +315,6 @@ class RebuildSequenceView(Action):
use_case_markdown = await self._get_class_use_cases(ns_class_name)
if not use_case_markdown: # external class
await self.graph_db.insert(subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_="")
await self.graph_db.save()
return
block = f"## Use Cases\n{use_case_markdown}"
prompts_blocks.append(block)
@ -382,7 +349,6 @@ class RebuildSequenceView(Action):
await self.graph_db.insert(
subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
)
await self.graph_db.save()
async def _get_participants(self, ns_class_name: str) -> List[str]:
"""
@ -574,14 +540,12 @@ class RebuildSequenceView(Action):
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=concat_namespace("?", class_name)
)
await self.graph_db.save()
return
if len(participants) > 1:
for r in participants:
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(r.subject)
)
await self.graph_db.save()
return
participant = participants[0]
@ -619,4 +583,31 @@ class RebuildSequenceView(Action):
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(participant.subject)
)
await self.graph_db.save()
await self._save_sequence_view(subject=entry.subject, content=sequence_view)
async def _save_sequence_view(self, subject: str, content: str):
pattern = re.compile(r"[^a-zA-Z0-9]")
name = re.sub(pattern, "_", subject)
filename = Path(name).with_suffix(".sequence_diagram.mmd")
await self.context.repo.resources.data_api_design.save(filename=str(filename), content=content)
async def _search_participants(self, filename: str) -> Set:
    """Ask the LLM which class names are used in a source file.

    Args:
        filename: Passed to ``self._get_source_code`` to obtain the text sent to the LLM.

    Returns:
        The distinct class names listed under ``class_names`` in the first block that
        ``parse_json_code_block`` extracts from the LLM reply.
    """

    class _Data(BaseModel):
        # Shape of the JSON object the system prompt asks the LLM to return.
        class_names: List[str]
        reasons: List

    system_msgs = [
        "You are a tool for listing all class names used in a source file.",
        "Return a markdown JSON object with: "
        '- a "class_names" key containing the list of class names used in the file; '
        '- a "reasons" key lists all reason objects, each object containing a "class_name" key for class name, a "reference" key explaining the line where the class has been used.',
    ]
    source_text = await self._get_source_code(filename)
    reply = await self.llm.aask(msg=source_text, system_msgs=system_msgs)

    # Only the first parsed block is validated; any further blocks are ignored.
    first_block = parse_json_code_block(reply)[0]
    return set(_Data.model_validate_json(first_block).class_names)

View file

@ -722,14 +722,19 @@ class RepoParser(BaseModel):
path = Path(path)
if not path.exists():
return
init_file = path / "__init__.py"
if not init_file.exists():
raise ValueError("Failed to import module __init__ with error:No module named __init__.")
command = f"pyreverse {str(path)} -o dot"
result = subprocess.run(command, shell=True, check=True, cwd=str(path))
output_dir = path / "__dot__"
output_dir.mkdir(parents=True, exist_ok=True)
result = subprocess.run(command, shell=True, check=True, cwd=str(output_dir))
if result.returncode != 0:
raise ValueError(f"{result}")
class_view_pathname = path / "classes.dot"
class_view_pathname = output_dir / "classes.dot"
class_views = await self._parse_classes(class_view_pathname)
relationship_views = await self._parse_class_relationships(class_view_pathname)
packages_pathname = path / "packages.dot"
packages_pathname = output_dir / "packages.dot"
class_views, relationship_views, package_root = RepoParser._repair_namespaces(
class_views=class_views, relationship_views=relationship_views, path=path
)
@ -975,6 +980,8 @@ class RepoParser(BaseModel):
file_ns = file_ns[0:ix]
continue
break
if file_ns == "":
return ""
internal_ns = package[ix + 1 :]
ns = mappings[file_ns] + ":" + internal_ns.replace(".", ":")
return ns

View file

@ -14,7 +14,6 @@ from metagpt.actions.rebuild_class_view import RebuildClassView
from metagpt.llm import LLM
@pytest.mark.skip
@pytest.mark.asyncio
async def test_rebuild(context):
action = RebuildClassView(

View file

@ -47,6 +47,8 @@ async def test_rebuild(context, mocker):
context=context,
)
await action.run()
rows = await action.graph_db.select()
assert rows
assert context.repo.docs.graph_repo.changed_files