async def _create_mermaid_class_views(self, graph_db):
    """Write a Mermaid ``classDiagram`` covering every class found in ``graph_db``.

    The file is created under ``<workdir>/<DATA_API_DESIGN_FILE_REPO>`` and is
    named after the working directory, with its suffix set to ``.mmd``.
    Class bodies are written first and relationships afterwards.

    Args:
        graph_db: Graph repository that is queried for ``(?, IS, CLASS)``
            triples; the per-class helpers also insert into it.
    """
    output_dir = Path(CONFIG.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
    output_dir.mkdir(parents=True, exist_ok=True)
    pathname = (output_dir / CONFIG.git_repo.workdir.name).with_suffix(".mmd")
    async with aiofiles.open(str(pathname), mode="w", encoding="utf-8") as writer:
        content = "classDiagram\n"
        logger.debug(content)
        await writer.write(content)
        # class names
        rows = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
        class_distinct = set()
        relationship_distinct = set()
        for r in rows:
            await RebuildClassView._create_mermaid_class(r.subject, graph_db, writer, class_distinct)
        for r in rows:
            await RebuildClassView._create_mermaid_relationship(r.subject, graph_db, writer, relationship_distinct)


@staticmethod
async def _create_mermaid_class(ns_class_name, graph_db, file_writer, distinct):
    """Render one class as a Mermaid ``class`` block and record it in the graph.

    Args:
        ns_class_name: Namespaced class name; only names that split into
            exactly two parts are rendered.
        graph_db: Graph repository used to look up the class members and to
            store the resulting view as a ``HAS_CLASS_VIEW`` triple (JSON).
        file_writer: Async writer that receives the Mermaid text.
        distinct: Set of class names already rendered; it is consulted to
            skip repeats and updated after a successful write.
    """
    fields = split_namespace(ns_class_name)
    if len(fields) != 2:
        # Ignore sub-classes (more than two parts) and names without a
        # class part, which would otherwise raise IndexError below.
        return
    if ns_class_name in distinct:
        # Already rendered: do not emit the same class block twice.
        return

    class_view = ClassView(name=fields[1])
    rows = await graph_db.select(subject=ns_class_name)
    for r in rows:
        raw_name = split_namespace(r.object_)[-1]
        name, visibility, abstraction = RebuildClassView._parse_name(name=raw_name, language="python")
        if r.predicate == GraphKeyword.HAS_CLASS_PROPERTY:
            var_type = await RebuildClassView._parse_variable_type(r.object_, graph_db)
            class_view.attributes.append(
                ClassAttribute(name=name, visibility=visibility, abstraction=bool(abstraction), value_type=var_type)
            )
        elif r.predicate == GraphKeyword.HAS_CLASS_FUNCTION:
            method = ClassMethod(name=name, visibility=visibility, abstraction=bool(abstraction))
            await RebuildClassView._parse_function_args(method, r.object_, graph_db)
            class_view.methods.append(method)

    # update graph db
    await graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())

    content = class_view.get_mermaid(align=1)
    logger.debug(content)
    await file_writer.write(content)
    distinct.add(ns_class_name)


@staticmethod
async def _create_mermaid_relationship(ns_class_name, graph_db, file_writer, distinct):
    """Write the Mermaid relationship lines whose subject is ``ns_class_name``.

    Generalization, composition and aggregation triples are rendered as
    ``<object><arrow><subject>`` lines, one per line, tab-indented.

    Args:
        ns_class_name: Namespaced class name; only names that split into
            exactly two parts are handled.
        graph_db: Graph repository queried for the relationship triples.
        file_writer: Async writer that receives the Mermaid text.
        distinct: Set of links already written; it is consulted so that the
            same link is not emitted twice, and updated with every new link.
    """
    s_fields = split_namespace(ns_class_name)
    if len(s_fields) != 2:
        # Ignore sub-classes and names without a class part.
        return

    # Relationship kind -> Mermaid arrow, in the order they are emitted.
    arrows = {
        GENERALIZATION: " <|-- ",
        COMPOSITION: " *-- ",
        AGGREGATION: " o-- ",
    }
    lines = []
    for kind, arrow in arrows.items():
        predicate = GraphKeyword.IS + kind + GraphKeyword.OF
        rows = await graph_db.select(subject=ns_class_name, predicate=predicate)
        for r in rows:
            o_fields = split_namespace(r.object_)
            if len(o_fields) != 2:
                # Ignore sub-classes and objects without a class part
                # (indexing o_fields[1] would raise IndexError).
                continue
            link = f"{o_fields[1]}{arrow}{s_fields[1]}"
            if link in distinct:
                continue
            distinct.add(link)
            lines.append(f"\t{link}\n")

    if lines:
        content = "".join(lines)
        logger.debug(content)
        await file_writer.write(content)


@staticmethod
def _parse_name(name: str, language="python"):
    """Split a raw member name into its name, visibility and abstraction marker.

    A name that carries a closing ``</I>`` tag is treated as abstract; the
    text before the tag becomes the name and an optional leading ``<I>`` tag
    is dropped, so the underscore checks below see the bare identifier.

    Args:
        name: Raw member name, optionally wrapped as ``<I>name</I>``.
        language: Currently unused.

    Returns:
        A ``(name, visibility, abstraction)`` tuple, where ``visibility`` is
        ``"-"`` for names starting with ``__``, ``"#"`` for names starting
        with ``_`` and ``"+"`` otherwise, and ``abstraction`` is ``"*"`` for
        abstract names and ``""`` otherwise.
    """
    result = re.search(r"(?:<I>)?(.*?)</I>", name)

    abstraction = ""
    if result:
        name = result.group(1)
        abstraction = "*"
    if name.startswith("__"):
        visibility = "-"
    elif name.startswith("_"):
        visibility = "#"
    else:
        visibility = "+"
    return name, visibility, abstraction


@staticmethod
async def _parse_variable_type(ns_name, graph_db) -> str:
    """Return the type recorded for ``ns_name`` in its ``HAS_TYPE_DESC`` triple.

    The first matching description has its single quotes removed and is
    split on ``":"``; the last piece is taken as the type.

    Returns:
        The type followed by one trailing space, or ``""`` when there is no
        description, the description has no ``":"``, or the type is
        ``NoneType``.
    """
    rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_TYPE_DESC)
    if not rows:
        return ""
    vals = rows[0].object_.replace("'", "").split(":")
    if len(vals) == 1:
        return ""
    val = vals[-1].strip()
    return "" if val == "NoneType" else val + " "


@staticmethod
async def _parse_function_args(method: ClassMethod, ns_name: str, graph_db: GraphRepository):
    """Fill ``method.return_type`` and ``method.args`` from the ``HAS_ARGS_DESC`` triple.

    The description is expected to look like ``name(arg: type, ...): return_type``
    (inferred from the parsing below; confirm against the code that writes
    the triple). ``method`` is left untouched when there is no description or
    when it has no parenthesised argument list.

    Args:
        method: Method view that is updated in place.
        ns_name: Namespaced function name used as the triple subject.
        graph_db: Graph repository queried for the description.
    """
    rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_ARGS_DESC)
    if not rows:
        return
    info = rows[0].object_.replace("'", "")

    fs_tag = "("
    ix = info.find(fs_tag)
    fe_tag = "):"
    eix = info.rfind(fe_tag)
    if eix < 0:
        fe_tag = ")"
        eix = info.rfind(fe_tag)
    if ix < 0 or eix < ix:
        # No "(...)" span: slicing with -1 indexes would produce garbage.
        return

    args_info = info[ix + len(fs_tag) : eix].strip()
    return_type = info[eix + len(fe_tag) :].strip()
    if return_type == "None":
        return_type = ""
    if "(" in return_type:
        return_type = return_type.replace("(", "Tuple[").replace(")", "]")
    method.return_type = return_type

    # parse args
    if not args_info:
        return
    # Split on commas that are not nested inside any bracket pair, so that
    # types/defaults such as "Dict[str, int]" or "(int, str)" stay intact.
    args = []
    depth = 0
    start = 0
    for i, ch in enumerate(args_info):
        if ch in "([{":
            depth += 1
        elif ch in ")]}":
            depth -= 1
        elif ch == "," and depth == 0:
            args.append(args_info[start:i])
            start = i + 1
    args.append(args_info[start:])

    for arg in args:
        parts = arg.strip().split(":")
        arg_name = parts[0].strip()
        if not arg_name:
            # A stray or trailing comma yields an empty piece; skip it.
            continue
        if len(parts) == 1:
            method.args.append(ClassAttribute(name=arg_name))
            continue
        method.args.append(ClassAttribute(name=arg_name, value_type=parts[-1].strip()))
# mermaid class view
class ClassMeta(BaseModel):
    """Fields shared by every element of a Mermaid class view."""

    name: str = ""
    # Rendered as a trailing "*" by the get_mermaid() implementations below.
    abstraction: bool = False
    # Rendered as a trailing "$" by the get_mermaid() implementations below.
    static: bool = False
    # Prefix written before the element, e.g. "+", "#" or "-".
    visibility: str = ""


class ClassAttribute(ClassMeta):
    """An attribute of a class, also used for a method argument."""

    value_type: str = ""
    default_value: str = ""

    def get_mermaid(self, align=1) -> str:
        """Return the attribute as one Mermaid line, indented by ``align`` tabs.

        Layout: ``<tabs><visibility>[<type> ]<name>[=<default>][*][$]``.
        For string types the default is wrapped in double quotes after any
        embedded double quotes have been removed.
        """
        pieces = ["\t" * align, self.visibility]
        if self.value_type:
            pieces.append(f"{self.value_type} ")
        pieces.append(self.name)
        if self.default_value:
            if self.value_type in ("str", "string", "String"):
                unquoted = self.default_value.replace('"', "")
                pieces.append(f'="{unquoted}"')
            else:
                pieces.append(f"={self.default_value}")
        if self.abstraction:
            pieces.append("*")
        if self.static:
            pieces.append("$")
        return "".join(pieces)


class ClassMethod(ClassMeta):
    """A method of a class, with its arguments and return type."""

    args: List[ClassAttribute] = Field(default_factory=list)
    return_type: str = ""

    def get_mermaid(self, align=1) -> str:
        """Return the method as one Mermaid line, indented by ``align`` tabs.

        Layout: ``<tabs><visibility><name>(<arg>,<arg>...)[:<return type>][*][$]``;
        each argument is rendered without indentation.
        """
        rendered_args = ",".join(arg.get_mermaid(align=0) for arg in self.args)
        pieces = ["\t" * align, self.visibility, self.name, "(", rendered_args, ")"]
        if self.return_type:
            pieces.append(f":{self.return_type}")
        if self.abstraction:
            pieces.append("*")
        if self.static:
            pieces.append("$")
        return "".join(pieces)


class ClassView(ClassMeta):
    """A class together with its attributes and methods."""

    attributes: List[ClassAttribute] = Field(default_factory=list)
    methods: List[ClassMethod] = Field(default_factory=list)

    def get_mermaid(self, align=1) -> str:
        """Return the Mermaid ``class Name{ ... }`` block, newline-terminated.

        The opening and closing lines are indented by ``align`` tabs and the
        members by ``align + 1``; attributes come before methods.
        """
        indent = "\t" * align
        lines = [f"{indent}class {self.name}{{"]
        lines.extend(member.get_mermaid(align=align + 1) for member in self.attributes)
        lines.extend(member.get_mermaid(align=align + 1) for member in self.methods)
        lines.append(f"{indent}}}")
        return "\n".join(lines) + "\n"
import List from pydantic import BaseModel +from metagpt.logs import logger from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo from metagpt.utils.common import concat_namespace @@ -162,6 +163,8 @@ class GraphRepository(ABC): subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt ) for fn, desc in c.methods.items(): + if "" in desc and "" not in desc: + logger.error(desc) # class -> function await graph_db.insert( subject=c.package, diff --git a/tests/conftest.py b/tests/conftest.py index d88b31ce5..4caecc8ee 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ import asyncio import logging import re +import uuid from unittest.mock import Mock import pytest @@ -90,9 +91,9 @@ def loguru_caplog(caplog): # init & dispose git repo -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="function", autouse=True) def setup_and_teardown_git_repo(request): - CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / "unittest") + CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}") CONFIG.git_reinit = True # Destroy git repo at the end of the test session. 
def test_class_view():
    """Check the Mermaid text produced by ClassAttribute, ClassMethod and ClassView."""
    # Attributes: a non-string default is written bare, a string default is quoted.
    int_attr = ClassAttribute(name="a", value_type="int", default_value="0", visibility="+", abstraction=True)
    str_attr = ClassAttribute(name="b", value_type="str", default_value="0", visibility="#", static=True)
    assert int_attr.get_mermaid(align=1) == "\t+int a=0*"
    assert str_attr.get_mermaid(align=0) == '#str b="0"$'

    # Methods: one without arguments, one with arguments and a return type.
    abstract_method = ClassMethod(name="run", visibility="+", abstraction=True)
    typed_args = [ClassAttribute(name="a", value_type="str"), ClassAttribute(name="b", value_type="int")]
    static_method = ClassMethod(name="_test", visibility="#", static=True, args=typed_args, return_type="str")
    assert abstract_method.get_mermaid(align=1) == "\t+run()*"
    assert static_method.get_mermaid(align=0) == "#_test(str a,int b):str$"

    # Whole class: attributes first, then methods, each indented one level deeper.
    class_view = ClassView(name="A")
    class_view.attributes = [int_attr, str_attr]
    class_view.methods = [abstract_method, static_method]
    expected = 'class A{\n\t+int a=0*\n\t#str b="0"$\n\t+run()*\n\t#_test(str a,int b):str$\n}\n'
    assert class_view.get_mermaid(align=0) == expected