feat: +unit test

fixbug: Align to the same root directory in accordance with `class_views`

fixbug: The class has lost namespace information.

feat: + mock

fixbug: project name invalid

feat: +mermaid sequence diagram

feat: translate from code to mermaid sequence

feat: translate from code to mermaid sequence
This commit is contained in:
莘权 马 2024-01-04 14:27:25 +08:00
parent 1b06d933b2
commit c966138a74
14 changed files with 216 additions and 45 deletions

View file

@ -35,7 +35,6 @@ class PrepareDocuments(Action):
if path.exists() and not CONFIG.inc:
shutil.rmtree(path)
CONFIG.project_path = path
CONFIG.project_name = path.name
CONFIG.git_repo = GitRepository(local_path=path, auto_init=True)
async def run(self, with_messages, **kwargs):

View file

@ -33,11 +33,16 @@ class RebuildClassView(Action):
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
repo_parser = RepoParser(base_directory=Path(self.context))
class_views, relationship_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint
# use pylint
class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.context))
await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views)
symbols = repo_parser.generate_symbols() # use ast
# use ast
direction, diff_path = self._diff_path(path_root=Path(self.context).resolve(), package_root=package_root)
symbols = repo_parser.generate_symbols()
for file_info in symbols:
# Align to the same root directory in accordance with `class_views`.
file_info.file = self._align_root(file_info.file, direction, diff_path)
await GraphRepository.update_graph_db_with_file_info(graph_db, file_info)
await self._create_mermaid_class_views(graph_db=graph_db)
await graph_db.save()
@ -193,3 +198,20 @@ class RebuildClassView(Action):
method.args.append(ClassAttribute(name=parts[0].strip()))
continue
method.args.append(ClassAttribute(name=parts[0].strip(), value_type=parts[-1].strip()))
@staticmethod
def _diff_path(path_root: Path, package_root: Path) -> (str, str):
if len(str(path_root)) > len(str(package_root)):
return "+", str(path_root.relative_to(package_root))
if len(str(path_root)) < len(str(package_root)):
return "-", str(package_root.relative_to(path_root))
return "=", "."
@staticmethod
def _align_root(path: str, direction: str, diff_path: str):
if direction == "=":
return path
if direction == "+":
return diff_path + "/" + path
else:
return path[len(diff_path) + 1 :]

View file

@ -1,33 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19
@Author : mashenquan
@File : rebuild_class_view_an.py
@Desc : Defines `ActionNode` objects used by rebuild_class_view.py
"""
from metagpt.actions.action_node import ActionNode
CLASS_SOURCE_CODE_BLOCK = ActionNode(
key="Class View",
expected_type=str,
instruction='Generate the mermaid class diagram corresponding to source code in "context."',
example="""
classDiagram
class A {
-int x
+int y
-int speed
-int direction
+__init__(x: int, y: int, speed: int, direction: int)
+change_direction(new_direction: int) None
+move() None
}
""",
)
REBUILD_CLASS_VIEW_NODES = [
CLASS_SOURCE_CODE_BLOCK,
]
REBUILD_CLASS_VIEW_NODE = ActionNode.from_children("RebuildClassView", REBUILD_CLASS_VIEW_NODES)

View file

@ -0,0 +1,60 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4
@Author : mashenquan
@File : rebuild_sequence_view.py
@Desc : Rebuild sequence view info
"""
from __future__ import annotations
from pathlib import Path
from typing import List
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.logs import logger
from metagpt.utils.common import aread, list_files
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword
class RebuildSequenceView(Action):
async def run(self, with_messages=None, format=CONFIG.prompt_schema):
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
entries = await RebuildSequenceView._search_main_entry(graph_db)
for entry in entries:
await self._rebuild_sequence_view(entry, graph_db)
await graph_db.save()
@staticmethod
async def _search_main_entry(graph_db) -> List:
rows = await graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
tag = "__name__:__main__"
entries = []
for r in rows:
if tag in r.subject or tag in r.object_:
entries.append(r)
return entries
async def _rebuild_sequence_view(self, entry, graph_db):
filename = entry.subject.split(":", 1)[0]
src_filename = RebuildSequenceView._get_full_filename(root=self.context, pathname=filename)
content = await aread(filename=src_filename, encoding="utf-8")
content = f"```python\n{content}\n```\n\n---\nTranslate the code above into Mermaid Sequence Diagram."
data = await self.llm.aask(
msg=content, system_msgs=["You are a python code to Mermaid Sequence Diagram translator in function detail"]
)
await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=data)
logger.info(data)
@staticmethod
def _get_full_filename(root: str | Path, pathname: str | Path) -> Path | None:
files = list_files(root=root)
postfix = "/" + str(pathname)
for i in files:
if str(i).endswith(postfix):
return i
return None

View file

@ -0,0 +1,16 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4
@Author : mashenquan
@File : rebuild_sequence_view_an.py
"""
from metagpt.actions.action_node import ActionNode
from metagpt.utils.mermaid import MMC2
CODE_2_MERMAID_SEQUENCE_DIAGRAM = ActionNode(
key="Program call flow",
expected_type=str,
instruction='Translate the "context" content into "format example" format.',
example=MMC2,
)

View file

@ -14,6 +14,7 @@
from __future__ import annotations
import json
import uuid
from pathlib import Path
from typing import Optional
@ -117,7 +118,7 @@ class WritePRD(Action):
# if sas.result:
# logger.info(sas.result)
# logger.info(rsp)
project_name = CONFIG.project_name if CONFIG.project_name else ""
project_name = CONFIG.project_name or ""
context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name)
exclude = [PROJECT_NAME.key] if project_name else []
node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, exclude=exclude) # schema=schema
@ -183,6 +184,8 @@ class WritePRD(Action):
ws_name = CodeParser.parse_str(block="Project Name", text=prd)
if ws_name:
CONFIG.project_name = ws_name
if not CONFIG.project_name: # The LLM failed to provide a project name, and the user didn't provide one either.
CONFIG.project_name = "app" + uuid.uuid4().hex[:16]
CONFIG.git_repo.rename_root(CONFIG.project_name)
async def _is_bugfix(self, context) -> bool:

View file

@ -240,12 +240,12 @@ class RepoParser(BaseModel):
class_views = await self._parse_classes(class_view_pathname)
relationship_views = await self._parse_class_relationships(class_view_pathname)
packages_pathname = path / "packages.dot"
class_views, relationship_views = RepoParser._repair_namespaces(
class_views, relationship_views, package_root = RepoParser._repair_namespaces(
class_views=class_views, relationship_views=relationship_views, path=path
)
class_view_pathname.unlink(missing_ok=True)
packages_pathname.unlink(missing_ok=True)
return class_views, relationship_views
return class_views, relationship_views, package_root
async def _parse_classes(self, class_view_pathname):
class_views = []
@ -364,9 +364,9 @@ class RepoParser(BaseModel):
@staticmethod
def _repair_namespaces(
class_views: List[ClassInfo], relationship_views: List[ClassRelationship], path: str | Path
) -> (List[ClassInfo], List[ClassRelationship]):
) -> (List[ClassInfo], List[ClassRelationship], str):
if not class_views:
return []
return [], [], ""
c = class_views[0]
full_key = str(path).lstrip("/").replace("/", ".")
root_namespace = RepoParser._find_root(full_key, c.package)
@ -388,7 +388,7 @@ class RepoParser(BaseModel):
v.src = RepoParser._repair_ns(v.src, new_mappings)
v.dest = RepoParser._repair_ns(v.dest, new_mappings)
relationship_views[i] = v
return class_views, relationship_views
return class_views, relationship_views, root_path
@staticmethod
def _repair_ns(package, mappings):

View file

@ -550,3 +550,20 @@ async def read_file_block(filename: str | Path, lineno: int, end_lineno: int):
break
lines.append(line)
return "".join(lines)
def list_files(root: str | Path) -> List[Path]:
files = []
try:
directory_path = Path(root)
if not directory_path.exists():
return []
for file_path in directory_path.iterdir():
if file_path.is_file():
files.append(file_path)
else:
subfolder_files = list_files(root=file_path)
files.extend(subfolder_files)
except Exception as e:
logger.error(f"Error: {e}")
return files

View file

@ -135,12 +135,12 @@ class GraphRepository(ABC):
@staticmethod
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[ClassInfo]):
for c in class_views:
filename, class_name = c.package.split(":", 1)
filename, _ = c.package.split(":", 1)
await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
file_types = {".py": "python", ".js": "javascript"}
file_type = file_types.get(Path(filename).suffix, GraphKeyword.NULL)
await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=file_type)
await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_CLASS, object_=class_name)
await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_CLASS, object_=c.package)
await graph_db.insert(
subject=c.package,
predicate=GraphKeyword.IS,

View file

@ -46,6 +46,7 @@ extras_require["test"] = [
"chromadb==0.4.14",
"gradio==3.0.0",
"grpcio-status==1.48.2",
"mock==5.1.0",
]
extras_require["pyppeteer"] = [

File diff suppressed because one or more lines are too long

View file

@ -26,5 +26,32 @@ async def test_rebuild():
assert graph_file_repo.changed_files
@pytest.mark.parametrize(
("path", "direction", "diff", "want"),
[
("metagpt/startup.py", "=", ".", "metagpt/startup.py"),
("metagpt/startup.py", "+", "MetaGPT", "MetaGPT/metagpt/startup.py"),
("metagpt/startup.py", "-", "metagpt", "startup.py"),
],
)
def test_align_path(path, direction, diff, want):
res = RebuildClassView._align_root(path=path, direction=direction, diff_path=diff)
assert res == want
@pytest.mark.parametrize(
("path_root", "package_root", "want_direction", "want_diff"),
[
("/Users/x/github/MetaGPT/metagpt", "/Users/x/github/MetaGPT/metagpt", "=", "."),
("/Users/x/github/MetaGPT", "/Users/x/github/MetaGPT/metagpt", "-", "metagpt"),
("/Users/x/github/MetaGPT/metagpt", "/Users/x/github/MetaGPT", "+", "metagpt"),
],
)
def test_diff_path(path_root, package_root, want_direction, want_diff):
direction, diff = RebuildClassView._diff_path(path_root=Path(path_root), package_root=Path(package_root))
assert direction == want_direction
assert diff == want_diff
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,55 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4
@Author : mashenquan
@File : test_rebuild_sequence_view.py
"""
from pathlib import Path
import pytest
from metagpt.actions.rebuild_sequence_view import RebuildSequenceView
from metagpt.config import CONFIG
from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.llm import LLM
from metagpt.utils.common import aread
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import ChangeType
@pytest.mark.asyncio
async def test_rebuild():
# Mock
data = await aread(filename=Path(__file__).parent / "../../data/graph_db/networkx.json")
graph_db_filename = Path(CONFIG.git_repo.workdir.name).with_suffix(".json")
await FileRepository.save_file(
filename=str(graph_db_filename),
relative_path=GRAPH_REPO_FILE_REPO,
content=data,
)
CONFIG.git_repo.add_change({f"{GRAPH_REPO_FILE_REPO}/{graph_db_filename}": ChangeType.UNTRACTED})
CONFIG.git_repo.commit("commit1")
action = RebuildSequenceView(
name="RedBean", context=str(Path(__file__).parent.parent.parent.parent / "metagpt"), llm=LLM()
)
await action.run()
graph_file_repo = CONFIG.git_repo.new_file_repository(relative_path=GRAPH_REPO_FILE_REPO)
assert graph_file_repo.changed_files
@pytest.mark.parametrize(
("root", "pathname", "want"),
[
(Path(__file__).parent.parent.parent, "/".join(__file__.split("/")[-2:]), Path(__file__)),
(Path(__file__).parent.parent.parent, "f/g.txt", None),
],
)
def test_get_full_filename(root, pathname, want):
res = RebuildSequenceView._get_full_filename(root=root, pathname=pathname)
assert res == want
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -6,6 +6,8 @@
@File : test_skill_loader.py
@Desc : Unit tests.
"""
from pathlib import Path
import pytest
from metagpt.config import CONFIG
@ -23,7 +25,8 @@ async def test_suite():
{"id": 6, "name": "knowledge", "type": "builtin", "config": {}, "enabled": True},
{"id": 6, "name": "web_search", "type": "builtin", "config": {}, "enabled": True},
]
loader = await SkillsDeclaration.load()
pathname = Path(__file__).parent / "../../../docs/.well-known/skills.yaml"
loader = await SkillsDeclaration.load(skill_yaml_file_name=pathname)
skills = loader.get_skill_list()
assert skills
assert len(skills) >= 3