diff --git a/metagpt/actions/rebuild_class_view.py b/metagpt/actions/rebuild_class_view.py
index adc28ff9d..dbc11d14b 100644
--- a/metagpt/actions/rebuild_class_view.py
+++ b/metagpt/actions/rebuild_class_view.py
@@ -14,11 +14,16 @@ import aiofiles
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.const import (
- CLASS_VIEW_FILE_REPO,
+ AGGREGATION,
+ COMPOSITION,
DATA_API_DESIGN_FILE_REPO,
+ GENERALIZATION,
GRAPH_REPO_FILE_REPO,
)
+from metagpt.logs import logger
from metagpt.repo_parser import RepoParser
+from metagpt.schema import ClassAttribute, ClassMethod, ClassView
+from metagpt.utils.common import split_namespace
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
@@ -34,34 +39,157 @@ class RebuildClassView(Action):
symbols = repo_parser.generate_symbols() # use ast
for file_info in symbols:
await GraphRepository.update_graph_db_with_file_info(graph_db, file_info)
- # await graph_db.save(path=graph_repo_pathname.parent)
await self._create_mermaid_class_views(graph_db=graph_db)
- await self._save(graph_db=graph_db)
+ await graph_db.save()
async def _create_mermaid_class_views(self, graph_db):
path = Path(CONFIG.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
path.mkdir(parents=True, exist_ok=True)
- async with aiofiles.open(str(path / CONFIG.git_repo.workdir.name), mode="w", encoding="utf-8") as writer:
- await writer.write("classDiagram\n")
+ pathname = path / CONFIG.git_repo.workdir.name
+ async with aiofiles.open(str(pathname.with_suffix(".mmd")), mode="w", encoding="utf-8") as writer:
+ content = "classDiagram\n"
+ logger.debug(content)
+ await writer.write(content)
# class names
rows = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
- distinct = {}
+ class_distinct = set()
+ relationship_distinct = set()
for r in rows:
- await RebuildClassView._create_mermaid_class(r, graph_db, writer, distinct)
+ await RebuildClassView._create_mermaid_class(r.subject, graph_db, writer, class_distinct)
+ for r in rows:
+ await RebuildClassView._create_mermaid_relationship(r.subject, graph_db, writer, relationship_distinct)
@staticmethod
async def _create_mermaid_class(ns_class_name, graph_db, file_writer, distinct):
- pass
- # fields = split_namespace(ns_class_name)
- # await graph_db.select(subject=ns_class_name)
+ fields = split_namespace(ns_class_name)
+ if len(fields) > 2:
+ # Ignore sub-class
+ return
- async def _save(self, graph_db):
- class_view_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO)
- dataset = await graph_db.select(predicate=GraphKeyword.HAS_CLASS_VIEW)
- all_class_view = []
- for spo in dataset:
- title = f"---\ntitle: {spo.subject}\n---\n"
- filename = re.sub(r"[/:]", "_", spo.subject) + ".mmd"
- await class_view_file_repo.save(filename=filename, content=title + spo.object_)
- all_class_view.append(spo.object_)
- await class_view_file_repo.save(filename="all.mmd", content="\n".join(all_class_view))
+ class_view = ClassView(name=fields[1])
+ rows = await graph_db.select(subject=ns_class_name)
+ for r in rows:
+ name = split_namespace(r.object_)[-1]
+ name, visibility, abstraction = RebuildClassView._parse_name(name=name, language="python")
+ if r.predicate == GraphKeyword.HAS_CLASS_PROPERTY:
+ var_type = await RebuildClassView._parse_variable_type(r.object_, graph_db)
+ attribute = ClassAttribute(
+ name=name, visibility=visibility, abstraction=bool(abstraction), value_type=var_type
+ )
+ class_view.attributes.append(attribute)
+ elif r.predicate == GraphKeyword.HAS_CLASS_FUNCTION:
+ method = ClassMethod(name=name, visibility=visibility, abstraction=bool(abstraction))
+ await RebuildClassView._parse_function_args(method, r.object_, graph_db)
+ class_view.methods.append(method)
+
+ # update graph db
+ await graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
+
+ content = class_view.get_mermaid(align=1)
+ logger.debug(content)
+ await file_writer.write(content)
+ distinct.add(ns_class_name)
+
+ @staticmethod
+ async def _create_mermaid_relationship(ns_class_name, graph_db, file_writer, distinct):
+ s_fields = split_namespace(ns_class_name)
+ if len(s_fields) > 2:
+ # Ignore sub-class
+ return
+
+ predicates = {GraphKeyword.IS + v + GraphKeyword.OF: v for v in [GENERALIZATION, COMPOSITION, AGGREGATION]}
+ mappings = {
+ GENERALIZATION: " <|-- ",
+ COMPOSITION: " *-- ",
+ AGGREGATION: " o-- ",
+ }
+ content = ""
+ for p, v in predicates.items():
+ rows = await graph_db.select(subject=ns_class_name, predicate=p)
+ for r in rows:
+ o_fields = split_namespace(r.object_)
+ if len(o_fields) > 2:
+ # Ignore sub-class
+ continue
+ relationship = mappings.get(v, " .. ")
+ link = f"{o_fields[1]}{relationship}{s_fields[1]}"
+ distinct.add(link)
+ content += f"\t{link}\n"
+
+ if content:
+ logger.debug(content)
+ await file_writer.write(content)
+
+ @staticmethod
+ def _parse_name(name: str, language="python"):
+ pattern = re.compile(r"(.*?)<\/I>")
+ result = re.search(pattern, name)
+
+ abstraction = ""
+ if result:
+ name = result.group(1)
+ abstraction = "*"
+ if name.startswith("__"):
+ visibility = "-"
+ elif name.startswith("_"):
+ visibility = "#"
+ else:
+ visibility = "+"
+ return name, visibility, abstraction
+
+ @staticmethod
+ async def _parse_variable_type(ns_name, graph_db) -> str:
+ rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_TYPE_DESC)
+ if not rows:
+ return ""
+ vals = rows[0].object_.replace("'", "").split(":")
+ if len(vals) == 1:
+ return ""
+ val = vals[-1].strip()
+ return "" if val == "NoneType" else val + " "
+
+ @staticmethod
+ async def _parse_function_args(method: ClassMethod, ns_name: str, graph_db: GraphRepository):
+ rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_ARGS_DESC)
+ if not rows:
+ return
+ info = rows[0].object_.replace("'", "")
+
+ fs_tag = "("
+ ix = info.find(fs_tag)
+ fe_tag = "):"
+ eix = info.rfind(fe_tag)
+ if eix < 0:
+ fe_tag = ")"
+ eix = info.rfind(fe_tag)
+ args_info = info[ix + len(fs_tag) : eix].strip()
+ method.return_type = info[eix + len(fe_tag) :].strip()
+ if method.return_type == "None":
+ method.return_type = ""
+ if "(" in method.return_type:
+ method.return_type = method.return_type.replace("(", "Tuple[").replace(")", "]")
+
+ # parse args
+ if not args_info:
+ return
+ splitter_ixs = []
+ cost = 0
+ for i in range(len(args_info)):
+ if args_info[i] == "[":
+ cost += 1
+ elif args_info[i] == "]":
+ cost -= 1
+ if args_info[i] == "," and cost == 0:
+ splitter_ixs.append(i)
+ splitter_ixs.append(len(args_info))
+ args = []
+ ix = 0
+ for eix in splitter_ixs:
+ args.append(args_info[ix:eix])
+ ix = eix + 1
+ for arg in args:
+ parts = arg.strip().split(":")
+ if len(parts) == 1:
+ method.args.append(ClassAttribute(name=parts[0].strip()))
+ continue
+ method.args.append(ClassAttribute(name=parts[0].strip(), value_type=parts[-1].strip()))
diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py
index 25c4912c3..7377442b5 100644
--- a/metagpt/actions/write_code.py
+++ b/metagpt/actions/write_code.py
@@ -130,7 +130,7 @@ class WriteCode(Action):
if not coding_context.code_doc:
# avoid root_path pydantic ValidationError if use WriteCode alone
root_path = CONFIG.src_workspace if CONFIG.src_workspace else ""
- coding_context.code_doc = Document(filename=coding_context.filename, root_path=root_path)
+ coding_context.code_doc = Document(filename=coding_context.filename, root_path=str(root_path))
coding_context.code_doc.content = code
return coding_context
diff --git a/metagpt/repo_parser.py b/metagpt/repo_parser.py
index f4a9a7f3a..465f40d63 100644
--- a/metagpt/repo_parser.py
+++ b/metagpt/repo_parser.py
@@ -155,7 +155,8 @@ class RepoParser(BaseModel):
else:
raise NotImplementedError(f"Not implement:{val}")
return code_block
- raise NotImplementedError(f"Not implement code block:{node.lineno}, {node.end_lineno}, {any_to_str(node)}")
+ logger.warning(f"Unsupported code block:{node.lineno}, {node.end_lineno}, {any_to_str(node)}")
+ return None
@staticmethod
def _parse_expr(node) -> List:
@@ -193,7 +194,7 @@ class RepoParser(BaseModel):
tokens.append(v)
return tokens
except Exception as e:
- logger.warning(e)
+ logger.warning(f"Unsupported if: {n}, err:{e}")
return tokens
@staticmethod
@@ -220,8 +221,7 @@ class RepoParser(BaseModel):
raise NotImplementedError(f"Not implement:{node}")
return func(node)
except Exception as e:
- logger.warning(e)
- raise e
+ logger.warning(f"Unsupported variable:{node}, err:{e}")
@staticmethod
def _parse_assign(node):
@@ -274,7 +274,7 @@ class RepoParser(BaseModel):
class_views.append(class_info)
return class_views
- async def _parse_class_relationships(self, class_view_pathname) -> List[ClassRelationShip]:
+ async def _parse_class_relationships(self, class_view_pathname) -> List[ClassRelationship]:
relationship_views = []
if not class_view_pathname.exists():
return relationship_views
@@ -365,7 +365,7 @@ class RepoParser(BaseModel):
@staticmethod
def _repair_namespaces(
class_views: List[ClassInfo], relationship_views: List[ClassRelationship], path: str | Path
- ) -> (List[ClassInfo], List[ClassRelationShip]):
+ ) -> (List[ClassInfo], List[ClassRelationship]):
if not class_views:
return []
c = class_views[0]
diff --git a/metagpt/schema.py b/metagpt/schema.py
index e36bef395..02d44f767 100644
--- a/metagpt/schema.py
+++ b/metagpt/schema.py
@@ -451,3 +451,63 @@ class CodeSummarizeContext(BaseModel):
class BugFixContext(BaseContext):
filename: str = ""
+
+
+# mermaid class view
+class ClassMeta(BaseModel):
+ name: str = ""
+ abstraction: bool = False
+ static: bool = False
+ visibility: str = ""
+
+
+class ClassAttribute(ClassMeta):
+ value_type: str = ""
+ default_value: str = ""
+
+ def get_mermaid(self, align=1) -> str:
+ content = "".join(["\t" for i in range(align)]) + self.visibility
+ if self.value_type:
+ content += self.value_type + " "
+ content += self.name
+ if self.default_value:
+ content += "="
+ if self.value_type not in ["str", "string", "String"]:
+ content += self.default_value
+ else:
+ content += '"' + self.default_value.replace('"', "") + '"'
+ if self.abstraction:
+ content += "*"
+ if self.static:
+ content += "$"
+ return content
+
+
+class ClassMethod(ClassMeta):
+ args: List[ClassAttribute] = Field(default_factory=list)
+ return_type: str = ""
+
+ def get_mermaid(self, align=1) -> str:
+ content = "".join(["\t" for i in range(align)]) + self.visibility
+ content += self.name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")"
+ if self.return_type:
+ content += ":" + self.return_type
+ if self.abstraction:
+ content += "*"
+ if self.static:
+ content += "$"
+ return content
+
+
+class ClassView(ClassMeta):
+ attributes: List[ClassAttribute] = Field(default_factory=list)
+ methods: List[ClassMethod] = Field(default_factory=list)
+
+ def get_mermaid(self, align=1) -> str:
+ content = "".join(["\t" for i in range(align)]) + "class " + self.name + "{\n"
+ for v in self.attributes:
+ content += v.get_mermaid(align=align + 1) + "\n"
+ for v in self.methods:
+ content += v.get_mermaid(align=align + 1) + "\n"
+ content += "".join(["\t" for i in range(align)]) + "}\n"
+ return content
diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py
index b5bb41f26..0032f0b0d 100644
--- a/metagpt/utils/common.py
+++ b/metagpt/utils/common.py
@@ -408,7 +408,7 @@ def concat_namespace(*args) -> str:
def split_namespace(ns_class_name: str) -> List[str]:
- pass
+ return ns_class_name.split(":")
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
diff --git a/metagpt/utils/graph_repository.py b/metagpt/utils/graph_repository.py
index f9eb273f9..88946c98e 100644
--- a/metagpt/utils/graph_repository.py
+++ b/metagpt/utils/graph_repository.py
@@ -13,6 +13,7 @@ from typing import List
from pydantic import BaseModel
+from metagpt.logs import logger
from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo
from metagpt.utils.common import concat_namespace
@@ -162,6 +163,8 @@ class GraphRepository(ABC):
subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
)
for fn, desc in c.methods.items():
+ if "" in desc and "" not in desc:
+ logger.error(desc)
# class -> function
await graph_db.insert(
subject=c.package,
diff --git a/tests/conftest.py b/tests/conftest.py
index d88b31ce5..4caecc8ee 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,6 +9,7 @@
import asyncio
import logging
import re
+import uuid
from unittest.mock import Mock
import pytest
@@ -90,9 +91,9 @@ def loguru_caplog(caplog):
# init & dispose git repo
-@pytest.fixture(scope="session", autouse=True)
+@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown_git_repo(request):
- CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / "unittest")
+ CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}")
CONFIG.git_reinit = True
# Destroy git repo at the end of the test session.
diff --git a/tests/metagpt/actions/test_rebuild_class_view.py b/tests/metagpt/actions/test_rebuild_class_view.py
index 941a32a3d..0103e9d05 100644
--- a/tests/metagpt/actions/test_rebuild_class_view.py
+++ b/tests/metagpt/actions/test_rebuild_class_view.py
@@ -11,6 +11,8 @@ from pathlib import Path
import pytest
from metagpt.actions.rebuild_class_view import RebuildClassView
+from metagpt.config import CONFIG
+from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.llm import LLM
@@ -20,6 +22,8 @@ async def test_rebuild():
name="RedBean", context=str(Path(__file__).parent.parent.parent.parent / "metagpt"), llm=LLM()
)
await action.run()
+ graph_file_repo = CONFIG.git_repo.new_file_repository(relative_path=GRAPH_REPO_FILE_REPO)
+ assert graph_file_repo.changed_files
if __name__ == "__main__":
diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py
index 816c186e2..b6e334fbe 100644
--- a/tests/metagpt/test_schema.py
+++ b/tests/metagpt/test_schema.py
@@ -19,6 +19,9 @@ from metagpt.config import CONFIG
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.schema import (
AIMessage,
+ ClassAttribute,
+ ClassMethod,
+ ClassView,
CodeSummarizeContext,
Document,
Message,
@@ -156,5 +159,30 @@ def test_CodeSummarizeContext(file_list, want):
assert want in m
+def test_class_view():
+ attr_a = ClassAttribute(name="a", value_type="int", default_value="0", visibility="+", abstraction=True)
+ assert attr_a.get_mermaid(align=1) == "\t+int a=0*"
+ attr_b = ClassAttribute(name="b", value_type="str", default_value="0", visibility="#", static=True)
+ assert attr_b.get_mermaid(align=0) == '#str b="0"$'
+ class_view = ClassView(name="A")
+ class_view.attributes = [attr_a, attr_b]
+
+ method_a = ClassMethod(name="run", visibility="+", abstraction=True)
+ assert method_a.get_mermaid(align=1) == "\t+run()*"
+ method_b = ClassMethod(
+ name="_test",
+ visibility="#",
+ static=True,
+ args=[ClassAttribute(name="a", value_type="str"), ClassAttribute(name="b", value_type="int")],
+ return_type="str",
+ )
+ assert method_b.get_mermaid(align=0) == "#_test(str a,int b):str$"
+ class_view.methods = [method_a, method_b]
+ assert (
+ class_view.get_mermaid(align=0)
+ == 'class A{\n\t+int a=0*\n\t#str b="0"$\n\t+run()*\n\t#_test(str a,int b):str$\n}\n'
+ )
+
+
if __name__ == "__main__":
pytest.main([__file__, "-s"])
diff --git a/tests/metagpt/utils/test_common.py b/tests/metagpt/utils/test_common.py
index 0342a92af..9b1fa878e 100644
--- a/tests/metagpt/utils/test_common.py
+++ b/tests/metagpt/utils/test_common.py
@@ -36,6 +36,7 @@ from metagpt.utils.common import (
read_file_block,
read_json_file,
require_python_version,
+ split_namespace,
)
@@ -163,6 +164,23 @@ class TestGetProjectRoot:
assert concat_namespace("a", "b", "c", "e") == "a:b:c:e"
assert concat_namespace("a", "b", "c", "e", "f") == "a:b:c:e:f"
+ @pytest.mark.parametrize(
+ ("val", "want"),
+ [
+ (
+ "tests/metagpt/test_role.py:test_react:Input:subscription",
+ ["tests/metagpt/test_role.py", "test_react", "Input", "subscription"],
+ ),
+ (
+ "tests/metagpt/test_role.py:test_react:Input:goal",
+ ["tests/metagpt/test_role.py", "test_react", "Input", "goal"],
+ ),
+ ],
+ )
+ def test_split_namespace(self, val, want):
+ res = split_namespace(val)
+ assert res == want
+
def test_read_json_file(self):
assert read_json_file(str(Path(__file__).parent / "../../data/ut_writer/yft_swaggerApi.json"), encoding="utf-8")
with pytest.raises(FileNotFoundError):