diff --git a/.gitignore b/.gitignore
index 1613a638d..cec4b10e4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -171,3 +171,4 @@ tests/metagpt/utils/file_repo_git
*.png
htmlcov
htmlcov.*
+*.dot
diff --git a/metagpt/actions/rebuild_class_view.py b/metagpt/actions/rebuild_class_view.py
index 66bc2c7ab..adc28ff9d 100644
--- a/metagpt/actions/rebuild_class_view.py
+++ b/metagpt/actions/rebuild_class_view.py
@@ -9,52 +9,51 @@
import re
from pathlib import Path
+import aiofiles
+
from metagpt.actions import Action
from metagpt.config import CONFIG
-from metagpt.const import CLASS_VIEW_FILE_REPO, GRAPH_REPO_FILE_REPO
+from metagpt.const import (
+ CLASS_VIEW_FILE_REPO,
+ DATA_API_DESIGN_FILE_REPO,
+ GRAPH_REPO_FILE_REPO,
+)
from metagpt.repo_parser import RepoParser
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class RebuildClassView(Action):
- def __init__(self, name="", context=None, llm=None):
- super().__init__(name=name, context=context, llm=llm)
-
async def run(self, with_messages=None, format=CONFIG.prompt_schema):
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
- repo_parser = RepoParser(base_directory=self.context)
- class_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint
+ repo_parser = RepoParser(base_directory=Path(self.context))
+ class_views, relationship_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint
await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
+ await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views)
symbols = repo_parser.generate_symbols() # use ast
for file_info in symbols:
await GraphRepository.update_graph_db_with_file_info(graph_db, file_info)
- await self._create_mermaid_class_view(graph_db=graph_db)
+ # await graph_db.save(path=graph_repo_pathname.parent)
+ await self._create_mermaid_class_views(graph_db=graph_db)
await self._save(graph_db=graph_db)
- async def _create_mermaid_class_view(self, graph_db):
- pass
- # dataset = await graph_db.select(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_PAGE_INFO)
- # if not dataset:
- # logger.warning(f"No page info for {concat_namespace(filename, class_name)}")
- # return
- # code_block_info = CodeBlockInfo.parse_raw(dataset[0].object_)
- # src_code = await read_file_block(filename=Path(self.context) / filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno)
- # code_type = ""
- # dataset = await graph_db.select(subject=filename, predicate=GraphKeyword.IS)
- # for spo in dataset:
- # if spo.object_ in ["javascript", "python"]:
- # code_type = spo.object_
- # break
+ async def _create_mermaid_class_views(self, graph_db):
+ path = Path(CONFIG.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
+ path.mkdir(parents=True, exist_ok=True)
+ async with aiofiles.open(str(path / CONFIG.git_repo.workdir.name), mode="w", encoding="utf-8") as writer:
+ await writer.write("classDiagram\n")
+ # class names
+ rows = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
+ distinct = {}
+ for r in rows:
+ await RebuildClassView._create_mermaid_class(r, graph_db, writer, distinct)
- # try:
- # node = await REBUILD_CLASS_VIEW_NODE.fill(context=f"```{code_type}\n{src_code}\n```", llm=self.llm, to=format)
- # class_view = node.instruct_content.model_dump()["Class View"]
- # except Exception as e:
- # class_view = RepoParser.rebuild_class_view(src_code, code_type)
- # await graph_db.insert(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_CLASS_VIEW, object_=class_view)
- # logger.info(f"{concat_namespace(filename, class_name)} {GraphKeyword.HAS_CLASS_VIEW} {class_view}")
+ @staticmethod
+ async def _create_mermaid_class(ns_class_name, graph_db, file_writer, distinct):
+ pass
+ # fields = split_namespace(ns_class_name)
+ # await graph_db.select(subject=ns_class_name)
async def _save(self, graph_db):
class_view_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO)
diff --git a/metagpt/const.py b/metagpt/const.py
index a57be641b..811ff9516 100644
--- a/metagpt/const.py
+++ b/metagpt/const.py
@@ -126,3 +126,8 @@ LLM_API_TIMEOUT = 300
# Message id
IGNORED_MESSAGE_ID = "0"
+
+# Class Relationship
+GENERALIZATION = "Generalize"
+COMPOSITION = "Composite"
+AGGREGATION = "Aggregate"
diff --git a/metagpt/repo_parser.py b/metagpt/repo_parser.py
index 9f3a1bac4..f4a9a7f3a 100644
--- a/metagpt/repo_parser.py
+++ b/metagpt/repo_parser.py
@@ -13,15 +13,15 @@ import re
import subprocess
from pathlib import Path
from pprint import pformat
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional
-import aiofiles
import pandas as pd
from pydantic import BaseModel, Field
from metagpt.config import CONFIG
+from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
from metagpt.logs import logger
-from metagpt.utils.common import any_to_str
+from metagpt.utils.common import any_to_str, aread
from metagpt.utils.exceptions import handle_exception
@@ -48,6 +48,13 @@ class ClassInfo(BaseModel):
methods: Dict[str, str] = Field(default_factory=dict)
+class ClassRelationship(BaseModel):
+ src: str = ""
+ dest: str = ""
+ relationship: str = ""
+ label: Optional[str] = None
+
+
class RepoParser(BaseModel):
base_directory: Path = Field(default=None)
@@ -62,7 +69,8 @@ class RepoParser(BaseModel):
file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory)))
for node in tree:
info = RepoParser.node_to_str(node)
- file_info.page_info.append(info)
+ if info:
+ file_info.page_info.append(info)
if isinstance(node, ast.ClassDef):
class_methods = [m.name for m in node.body if is_func(m)]
file_info.classes.append({"name": node.name, "methods": class_methods})
@@ -111,7 +119,9 @@ class RepoParser(BaseModel):
self.generate_dataframe_structure(output_path)
@staticmethod
- def node_to_str(node) -> (int, int, str, str | Tuple):
+ def node_to_str(node) -> CodeBlockInfo | None:
+ if isinstance(node, ast.Try):
+ return None
if any_to_str(node) == any_to_str(ast.Expr):
return CodeBlockInfo(
lineno=node.lineno,
@@ -130,6 +140,7 @@ class RepoParser(BaseModel):
},
any_to_str(ast.If): RepoParser._parse_if,
any_to_str(ast.AsyncFunctionDef): lambda x: x.name,
+ any_to_str(ast.AnnAssign): lambda x: RepoParser._parse_variable(x.target),
}
func = mappings.get(any_to_str(node))
if func:
@@ -165,22 +176,52 @@ class RepoParser(BaseModel):
@staticmethod
def _parse_if(n):
- tokens = [RepoParser._parse_variable(n.test.left)]
- for item in n.test.comparators:
- tokens.append(RepoParser._parse_variable(item))
+ tokens = []
+ try:
+ if isinstance(n.test, ast.BoolOp):
+ tokens = []
+ for v in n.test.values:
+ tokens.extend(RepoParser._parse_if_compare(v))
+ return tokens
+ if isinstance(n.test, ast.Compare):
+ v = RepoParser._parse_variable(n.test.left)
+ if v:
+ tokens.append(v)
+ for item in n.test.comparators:
+ v = RepoParser._parse_variable(item)
+ if v:
+ tokens.append(v)
+ return tokens
+ except Exception as e:
+ logger.warning(e)
return tokens
+ @staticmethod
+ def _parse_if_compare(n):
+ if hasattr(n, "left"):
+            return [v] if (v := RepoParser._parse_variable(n.left)) else []
+ else:
+ return []
+
@staticmethod
def _parse_variable(node):
- funcs = {
- any_to_str(ast.Constant): lambda x: x.value,
- any_to_str(ast.Name): lambda x: x.id,
- any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}",
- }
- func = funcs.get(any_to_str(node))
- if not func:
- raise NotImplementedError(f"Not implement:{node}")
- return func(node)
+ try:
+ funcs = {
+ any_to_str(ast.Constant): lambda x: x.value,
+ any_to_str(ast.Name): lambda x: x.id,
+ any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}"
+ if hasattr(x.value, "id")
+ else f"{x.attr}",
+ any_to_str(ast.Call): lambda x: RepoParser._parse_variable(x.func),
+ any_to_str(ast.Tuple): lambda x: "",
+ }
+ func = funcs.get(any_to_str(node))
+ if not func:
+ raise NotImplementedError(f"Not implement:{node}")
+ return func(node)
+ except Exception as e:
+ logger.warning(e)
+ raise e
@staticmethod
def _parse_assign(node):
@@ -198,18 +239,21 @@ class RepoParser(BaseModel):
raise ValueError(f"{result}")
class_view_pathname = path / "classes.dot"
class_views = await self._parse_classes(class_view_pathname)
+ relationship_views = await self._parse_class_relationships(class_view_pathname)
packages_pathname = path / "packages.dot"
- class_views = RepoParser._repair_namespaces(class_views=class_views, path=path)
+ class_views, relationship_views = RepoParser._repair_namespaces(
+ class_views=class_views, relationship_views=relationship_views, path=path
+ )
class_view_pathname.unlink(missing_ok=True)
packages_pathname.unlink(missing_ok=True)
- return class_views
+ return class_views, relationship_views
async def _parse_classes(self, class_view_pathname):
class_views = []
if not class_view_pathname.exists():
return class_views
- async with aiofiles.open(str(class_view_pathname), mode="r") as reader:
- lines = await reader.readlines()
+ data = await aread(filename=class_view_pathname, encoding="utf-8")
+ lines = data.split("\n")
for line in lines:
package_name, info = RepoParser._split_class_line(line)
if not package_name:
@@ -230,6 +274,19 @@ class RepoParser(BaseModel):
class_views.append(class_info)
return class_views
+    async def _parse_class_relationships(self, class_view_pathname) -> List[ClassRelationship]:
+ relationship_views = []
+ if not class_view_pathname.exists():
+ return relationship_views
+ data = await aread(filename=class_view_pathname, encoding="utf-8")
+ lines = data.split("\n")
+ for line in lines:
+ relationship = RepoParser._split_relationship_line(line)
+ if not relationship:
+ continue
+ relationship_views.append(relationship)
+ return relationship_views
+
@staticmethod
def _split_class_line(line):
part_splitor = '" ['
@@ -248,6 +305,40 @@ class RepoParser(BaseModel):
         info = re.sub(r"<br[^>]*>", "\n", info)
return class_name, info
+ @staticmethod
+ def _split_relationship_line(line):
+ splitters = [" -> ", " [", "];"]
+ idxs = []
+ for tag in splitters:
+ if tag not in line:
+ return None
+ idxs.append(line.find(tag))
+ ret = ClassRelationship()
+ ret.src = line[0 : idxs[0]].strip('"')
+ ret.dest = line[idxs[0] + len(splitters[0]) : idxs[1]].strip('"')
+ properties = line[idxs[1] + len(splitters[1]) : idxs[2]].strip(" ")
+ mappings = {
+ 'arrowhead="empty"': GENERALIZATION,
+ 'arrowhead="diamond"': COMPOSITION,
+ 'arrowhead="odiamond"': AGGREGATION,
+ }
+ for k, v in mappings.items():
+ if k in properties:
+ ret.relationship = v
+ if v != GENERALIZATION:
+ ret.label = RepoParser._get_label(properties)
+ break
+ return ret
+
+ @staticmethod
+ def _get_label(line):
+ tag = 'label="'
+ if tag not in line:
+ return ""
+ ix = line.find(tag)
+ eix = line.find('"', ix + len(tag))
+ return line[ix + len(tag) : eix]
+
@staticmethod
def _create_path_mapping(path: str | Path) -> Dict[str, str]:
mappings = {
@@ -272,7 +363,9 @@ class RepoParser(BaseModel):
return mappings
@staticmethod
- def _repair_namespaces(class_views: List[ClassInfo], path: str | Path) -> List[ClassInfo]:
+ def _repair_namespaces(
+ class_views: List[ClassInfo], relationship_views: List[ClassRelationship], path: str | Path
+    ) -> (List[ClassInfo], List[ClassRelationship]):
if not class_views:
-            return []
+            return [], relationship_views
c = class_views[0]
@@ -291,7 +384,12 @@ class RepoParser(BaseModel):
for c in class_views:
c.package = RepoParser._repair_ns(c.package, new_mappings)
- return class_views
+ for i in range(len(relationship_views)):
+ v = relationship_views[i]
+ v.src = RepoParser._repair_ns(v.src, new_mappings)
+ v.dest = RepoParser._repair_ns(v.dest, new_mappings)
+ relationship_views[i] = v
+ return class_views, relationship_views
@staticmethod
def _repair_ns(package, mappings):
diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py
index 5999b2e11..71faff834 100644
--- a/metagpt/utils/common.py
+++ b/metagpt/utils/common.py
@@ -419,6 +419,10 @@ def concat_namespace(*args) -> str:
return ":".join(str(value) for value in args)
+def split_namespace(ns_class_name: str) -> List[str]:
+    return ns_class_name.split(":")
+
+
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
"""
Generates a logging function to be used after a call is retried.
diff --git a/metagpt/utils/di_graph_repository.py b/metagpt/utils/di_graph_repository.py
index 08f4327fa..8bb5f9bb3 100644
--- a/metagpt/utils/di_graph_repository.py
+++ b/metagpt/utils/di_graph_repository.py
@@ -12,9 +12,9 @@ import json
from pathlib import Path
from typing import List
-import aiofiles
import networkx
+from metagpt.utils.common import aread, awrite
from metagpt.utils.graph_repository import SPO, GraphRepository
@@ -55,12 +55,10 @@ class DiGraphRepository(GraphRepository):
if not path.exists():
path.mkdir(parents=True, exist_ok=True)
pathname = Path(path) / self.name
- async with aiofiles.open(str(pathname.with_suffix(".json")), mode="w", encoding="utf-8") as writer:
- await writer.write(data)
+ await awrite(filename=pathname.with_suffix(".json"), data=data, encoding="utf-8")
async def load(self, pathname: str | Path):
- async with aiofiles.open(str(pathname), mode="r", encoding="utf-8") as reader:
- data = await reader.read(-1)
+ data = await aread(filename=pathname, encoding="utf-8")
m = json.loads(data)
self._repo = networkx.node_link_graph(m)
diff --git a/metagpt/utils/graph_repository.py b/metagpt/utils/graph_repository.py
index 37da3dee4..f9eb273f9 100644
--- a/metagpt/utils/graph_repository.py
+++ b/metagpt/utils/graph_repository.py
@@ -13,19 +13,24 @@ from typing import List
from pydantic import BaseModel
-from metagpt.repo_parser import ClassInfo, RepoFileInfo
+from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo
from metagpt.utils.common import concat_namespace
class GraphKeyword:
IS = "is"
+ OF = "Of"
+ ON = "On"
CLASS = "class"
FUNCTION = "function"
+ HAS_FUNCTION = "has_function"
SOURCE_CODE = "source_code"
NULL = ""
GLOBAL_VARIABLE = "global_variable"
CLASS_FUNCTION = "class_function"
CLASS_PROPERTY = "class_property"
+ HAS_CLASS_FUNCTION = "has_class_function"
+ HAS_CLASS_PROPERTY = "has_class_property"
HAS_CLASS = "has_class"
HAS_PAGE_INFO = "has_page_info"
HAS_CLASS_VIEW = "has_class_view"
@@ -73,11 +78,13 @@ class GraphRepository(ABC):
await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=file_type)
for c in file_info.classes:
class_name = c.get("name", "")
+ # file -> class
await graph_db.insert(
subject=file_info.file,
predicate=GraphKeyword.HAS_CLASS,
object_=concat_namespace(file_info.file, class_name),
)
+ # class detail
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name),
predicate=GraphKeyword.IS,
@@ -85,12 +92,22 @@ class GraphRepository(ABC):
)
methods = c.get("methods", [])
for fn in methods:
+ await graph_db.insert(
+ subject=concat_namespace(file_info.file, class_name),
+ predicate=GraphKeyword.HAS_CLASS_FUNCTION,
+ object_=concat_namespace(file_info.file, class_name, fn),
+ )
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
)
for f in file_info.functions:
+ # file -> function
+ await graph_db.insert(
+ subject=file_info.file, predicate=GraphKeyword.HAS_FUNCTION, object_=concat_namespace(file_info.file, f)
+ )
+ # function detail
await graph_db.insert(
subject=concat_namespace(file_info.file, f), predicate=GraphKeyword.IS, object_=GraphKeyword.FUNCTION
)
@@ -105,13 +122,13 @@ class GraphRepository(ABC):
await graph_db.insert(
subject=concat_namespace(file_info.file, *code_block.tokens),
predicate=GraphKeyword.HAS_PAGE_INFO,
- object_=code_block.json(ensure_ascii=False),
+ object_=code_block.model_dump_json(),
)
for k, v in code_block.properties.items():
await graph_db.insert(
subject=concat_namespace(file_info.file, k, v),
predicate=GraphKeyword.HAS_PAGE_INFO,
- object_=code_block.json(ensure_ascii=False),
+ object_=code_block.model_dump_json(),
)
@staticmethod
@@ -129,6 +146,13 @@ class GraphRepository(ABC):
object_=GraphKeyword.CLASS,
)
for vn, vt in c.attributes.items():
+ # class -> property
+ await graph_db.insert(
+ subject=c.package,
+ predicate=GraphKeyword.HAS_CLASS_PROPERTY,
+ object_=concat_namespace(c.package, vn),
+ )
+ # property detail
await graph_db.insert(
subject=concat_namespace(c.package, vn),
predicate=GraphKeyword.IS,
@@ -138,6 +162,13 @@ class GraphRepository(ABC):
subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
)
for fn, desc in c.methods.items():
+ # class -> function
+ await graph_db.insert(
+ subject=c.package,
+ predicate=GraphKeyword.HAS_CLASS_FUNCTION,
+ object_=concat_namespace(c.package, fn),
+ )
+ # function detail
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.IS,
@@ -148,3 +179,19 @@ class GraphRepository(ABC):
predicate=GraphKeyword.HAS_ARGS_DESC,
object_=desc,
)
+
+ @staticmethod
+ async def update_graph_db_with_class_relationship_views(
+ graph_db: "GraphRepository", relationship_views: List[ClassRelationship]
+ ):
+ for r in relationship_views:
+ await graph_db.insert(
+ subject=r.src, predicate=GraphKeyword.IS + r.relationship + GraphKeyword.OF, object_=r.dest
+ )
+ if not r.label:
+ continue
+ await graph_db.insert(
+ subject=r.src,
+ predicate=GraphKeyword.IS + r.relationship + GraphKeyword.ON,
+ object_=concat_namespace(r.dest, r.label),
+ )
diff --git a/tests/metagpt/actions/test_rebuild_class_view.py b/tests/metagpt/actions/test_rebuild_class_view.py
index 955c6ae3b..941a32a3d 100644
--- a/tests/metagpt/actions/test_rebuild_class_view.py
+++ b/tests/metagpt/actions/test_rebuild_class_view.py
@@ -16,7 +16,9 @@ from metagpt.llm import LLM
@pytest.mark.asyncio
async def test_rebuild():
- action = RebuildClassView(name="RedBean", context=Path(__file__).parent.parent, llm=LLM())
+ action = RebuildClassView(
+ name="RedBean", context=str(Path(__file__).parent.parent.parent.parent / "metagpt"), llm=LLM()
+ )
await action.run()