diff --git a/.gitignore b/.gitignore index 1613a638d..cec4b10e4 100644 --- a/.gitignore +++ b/.gitignore @@ -171,3 +171,4 @@ tests/metagpt/utils/file_repo_git *.png htmlcov htmlcov.* +*.dot diff --git a/metagpt/actions/rebuild_class_view.py b/metagpt/actions/rebuild_class_view.py index 66bc2c7ab..adc28ff9d 100644 --- a/metagpt/actions/rebuild_class_view.py +++ b/metagpt/actions/rebuild_class_view.py @@ -9,52 +9,51 @@ import re from pathlib import Path +import aiofiles + from metagpt.actions import Action from metagpt.config import CONFIG -from metagpt.const import CLASS_VIEW_FILE_REPO, GRAPH_REPO_FILE_REPO +from metagpt.const import ( + CLASS_VIEW_FILE_REPO, + DATA_API_DESIGN_FILE_REPO, + GRAPH_REPO_FILE_REPO, +) from metagpt.repo_parser import RepoParser from metagpt.utils.di_graph_repository import DiGraphRepository from metagpt.utils.graph_repository import GraphKeyword, GraphRepository class RebuildClassView(Action): - def __init__(self, name="", context=None, llm=None): - super().__init__(name=name, context=context, llm=llm) - async def run(self, with_messages=None, format=CONFIG.prompt_schema): graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json"))) - repo_parser = RepoParser(base_directory=self.context) - class_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint + repo_parser = RepoParser(base_directory=Path(self.context)) + class_views, relationship_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint await GraphRepository.update_graph_db_with_class_views(graph_db, class_views) + await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views) symbols = repo_parser.generate_symbols() # use ast for file_info in symbols: await GraphRepository.update_graph_db_with_file_info(graph_db, file_info) - await 
self._create_mermaid_class_view(graph_db=graph_db) + # await graph_db.save(path=graph_repo_pathname.parent) + await self._create_mermaid_class_views(graph_db=graph_db) await self._save(graph_db=graph_db) - async def _create_mermaid_class_view(self, graph_db): - pass - # dataset = await graph_db.select(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_PAGE_INFO) - # if not dataset: - # logger.warning(f"No page info for {concat_namespace(filename, class_name)}") - # return - # code_block_info = CodeBlockInfo.parse_raw(dataset[0].object_) - # src_code = await read_file_block(filename=Path(self.context) / filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno) - # code_type = "" - # dataset = await graph_db.select(subject=filename, predicate=GraphKeyword.IS) - # for spo in dataset: - # if spo.object_ in ["javascript", "python"]: - # code_type = spo.object_ - # break + async def _create_mermaid_class_views(self, graph_db): + path = Path(CONFIG.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO + path.mkdir(parents=True, exist_ok=True) + async with aiofiles.open(str(path / CONFIG.git_repo.workdir.name), mode="w", encoding="utf-8") as writer: + await writer.write("classDiagram\n") + # class names + rows = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS) + distinct = {} + for r in rows: + await RebuildClassView._create_mermaid_class(r, graph_db, writer, distinct) - # try: - # node = await REBUILD_CLASS_VIEW_NODE.fill(context=f"```{code_type}\n{src_code}\n```", llm=self.llm, to=format) - # class_view = node.instruct_content.model_dump()["Class View"] - # except Exception as e: - # class_view = RepoParser.rebuild_class_view(src_code, code_type) - # await graph_db.insert(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_CLASS_VIEW, object_=class_view) - # logger.info(f"{concat_namespace(filename, class_name)} {GraphKeyword.HAS_CLASS_VIEW} {class_view}") + @staticmethod + async def 
_create_mermaid_class(ns_class_name, graph_db, file_writer, distinct): + pass + # fields = split_namespace(ns_class_name) + # await graph_db.select(subject=ns_class_name) async def _save(self, graph_db): class_view_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO) diff --git a/metagpt/const.py b/metagpt/const.py index a57be641b..811ff9516 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -126,3 +126,8 @@ LLM_API_TIMEOUT = 300 # Message id IGNORED_MESSAGE_ID = "0" + +# Class Relationship +GENERALIZATION = "Generalize" +COMPOSITION = "Composite" +AGGREGATION = "Aggregate" diff --git a/metagpt/repo_parser.py b/metagpt/repo_parser.py index 9f3a1bac4..f4a9a7f3a 100644 --- a/metagpt/repo_parser.py +++ b/metagpt/repo_parser.py @@ -13,15 +13,15 @@ import re import subprocess from pathlib import Path from pprint import pformat -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional -import aiofiles import pandas as pd from pydantic import BaseModel, Field from metagpt.config import CONFIG +from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION from metagpt.logs import logger -from metagpt.utils.common import any_to_str +from metagpt.utils.common import any_to_str, aread from metagpt.utils.exceptions import handle_exception @@ -48,6 +48,13 @@ class ClassInfo(BaseModel): methods: Dict[str, str] = Field(default_factory=dict) +class ClassRelationship(BaseModel): + src: str = "" + dest: str = "" + relationship: str = "" + label: Optional[str] = None + + class RepoParser(BaseModel): base_directory: Path = Field(default=None) @@ -62,7 +69,8 @@ class RepoParser(BaseModel): file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory))) for node in tree: info = RepoParser.node_to_str(node) - file_info.page_info.append(info) + if info: + file_info.page_info.append(info) if isinstance(node, ast.ClassDef): class_methods = [m.name for m in node.body if is_func(m)] 
file_info.classes.append({"name": node.name, "methods": class_methods}) @@ -111,7 +119,9 @@ self.generate_dataframe_structure(output_path) @staticmethod - def node_to_str(node) -> (int, int, str, str | Tuple): + def node_to_str(node) -> CodeBlockInfo | None: + if isinstance(node, ast.Try): + return None if any_to_str(node) == any_to_str(ast.Expr): return CodeBlockInfo( lineno=node.lineno, @@ -130,6 +140,7 @@ }, any_to_str(ast.If): RepoParser._parse_if, any_to_str(ast.AsyncFunctionDef): lambda x: x.name, + any_to_str(ast.AnnAssign): lambda x: RepoParser._parse_variable(x.target), } func = mappings.get(any_to_str(node)) if func: @@ -165,22 +176,52 @@ @staticmethod def _parse_if(n): - tokens = [RepoParser._parse_variable(n.test.left)] - for item in n.test.comparators: - tokens.append(RepoParser._parse_variable(item)) + tokens = [] + try: + if isinstance(n.test, ast.BoolOp): + tokens = [] + for v in n.test.values: + tokens.extend(RepoParser._parse_if_compare(v)) + return tokens + if isinstance(n.test, ast.Compare): + v = RepoParser._parse_variable(n.test.left) + if v: + tokens.append(v) + for item in n.test.comparators: + v = RepoParser._parse_variable(item) + if v: + tokens.append(v) + return tokens + except Exception as e: + logger.warning(e) return tokens + @staticmethod + def _parse_if_compare(n): + if hasattr(n, "left"): + return [RepoParser._parse_variable(n.left)] + else: + return [] + @staticmethod def _parse_variable(node): - funcs = { - any_to_str(ast.Constant): lambda x: x.value, - any_to_str(ast.Name): lambda x: x.id, - any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}", - } - func = funcs.get(any_to_str(node)) - if not func: - raise NotImplementedError(f"Not implement:{node}") - return func(node) + try: + funcs = { + any_to_str(ast.Constant): lambda x: x.value, + any_to_str(ast.Name): lambda x: x.id, + any_to_str(ast.Attribute): lambda x: 
f"{x.value.id}.{x.attr}" + if hasattr(x.value, "id") + else f"{x.attr}", + any_to_str(ast.Call): lambda x: RepoParser._parse_variable(x.func), + any_to_str(ast.Tuple): lambda x: "", + } + func = funcs.get(any_to_str(node)) + if not func: + raise NotImplementedError(f"Not implement:{node}") + return func(node) + except Exception as e: + logger.warning(e) + raise e @staticmethod def _parse_assign(node): @@ -198,18 +239,21 @@ class RepoParser(BaseModel): raise ValueError(f"{result}") class_view_pathname = path / "classes.dot" class_views = await self._parse_classes(class_view_pathname) + relationship_views = await self._parse_class_relationships(class_view_pathname) packages_pathname = path / "packages.dot" - class_views = RepoParser._repair_namespaces(class_views=class_views, path=path) + class_views, relationship_views = RepoParser._repair_namespaces( + class_views=class_views, relationship_views=relationship_views, path=path + ) class_view_pathname.unlink(missing_ok=True) packages_pathname.unlink(missing_ok=True) - return class_views + return class_views, relationship_views async def _parse_classes(self, class_view_pathname): class_views = [] if not class_view_pathname.exists(): return class_views - async with aiofiles.open(str(class_view_pathname), mode="r") as reader: - lines = await reader.readlines() + data = await aread(filename=class_view_pathname, encoding="utf-8") + lines = data.split("\n") for line in lines: package_name, info = RepoParser._split_class_line(line) if not package_name: @@ -230,6 +274,19 @@ class RepoParser(BaseModel): class_views.append(class_info) return class_views + async def _parse_class_relationships(self, class_view_pathname) -> List[ClassRelationShip]: + relationship_views = [] + if not class_view_pathname.exists(): + return relationship_views + data = await aread(filename=class_view_pathname, encoding="utf-8") + lines = data.split("\n") + for line in lines: + relationship = RepoParser._split_relationship_line(line) + if not 
relationship: + continue + relationship_views.append(relationship) + return relationship_views + @staticmethod def _split_class_line(line): part_splitor = '" [' @@ -248,6 +305,40 @@ info = re.sub(r"]*>", "\n", info) return class_name, info + @staticmethod + def _split_relationship_line(line): + splitters = [" -> ", " [", "];"] + idxs = [] + for tag in splitters: + if tag not in line: + return None + idxs.append(line.find(tag)) + ret = ClassRelationship() + ret.src = line[0 : idxs[0]].strip('"') + ret.dest = line[idxs[0] + len(splitters[0]) : idxs[1]].strip('"') + properties = line[idxs[1] + len(splitters[1]) : idxs[2]].strip(" ") + mappings = { + 'arrowhead="empty"': GENERALIZATION, + 'arrowhead="diamond"': COMPOSITION, + 'arrowhead="odiamond"': AGGREGATION, + } + for k, v in mappings.items(): + if k in properties: + ret.relationship = v + if v != GENERALIZATION: + ret.label = RepoParser._get_label(properties) + break + return ret + + @staticmethod + def _get_label(line): + tag = 'label="' + if tag not in line: + return "" + ix = line.find(tag) + eix = line.find('"', ix + len(tag)) + return line[ix + len(tag) : eix] + @staticmethod def _create_path_mapping(path: str | Path) -> Dict[str, str]: mappings = { @@ -272,7 +363,9 @@ return mappings @staticmethod - def _repair_namespaces(class_views: List[ClassInfo], path: str | Path) -> List[ClassInfo]: + def _repair_namespaces( + class_views: List[ClassInfo], relationship_views: List[ClassRelationship], path: str | Path + ) -> (List[ClassInfo], List[ClassRelationship]): if not class_views: return [] c = class_views[0] @@ -291,7 +384,12 @@ for c in class_views: c.package = RepoParser._repair_ns(c.package, new_mappings) - return class_views + for i in range(len(relationship_views)): + v = relationship_views[i] + v.src = RepoParser._repair_ns(v.src, new_mappings) + v.dest = RepoParser._repair_ns(v.dest, new_mappings) + 
relationship_views[i] = v + return class_views, relationship_views @staticmethod def _repair_ns(package, mappings): diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 5999b2e11..71faff834 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -419,6 +419,10 @@ def concat_namespace(*args) -> str: return ":".join(str(value) for value in args) +def split_namespace(ns_class_name: str) -> List[str]: + return ns_class_name.split(":") + + def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]: """ Generates a logging function to be used after a call is retried. diff --git a/metagpt/utils/di_graph_repository.py b/metagpt/utils/di_graph_repository.py index 08f4327fa..8bb5f9bb3 100644 --- a/metagpt/utils/di_graph_repository.py +++ b/metagpt/utils/di_graph_repository.py @@ -12,9 +12,9 @@ import json from pathlib import Path from typing import List -import aiofiles import networkx +from metagpt.utils.common import aread, awrite from metagpt.utils.graph_repository import SPO, GraphRepository @@ -55,12 +55,10 @@ if not path.exists(): path.mkdir(parents=True, exist_ok=True) pathname = Path(path) / self.name - async with aiofiles.open(str(pathname.with_suffix(".json")), mode="w", encoding="utf-8") as writer: - await writer.write(data) + await awrite(filename=pathname.with_suffix(".json"), data=data, encoding="utf-8") async def load(self, pathname: str | Path): - async with aiofiles.open(str(pathname), mode="r", encoding="utf-8") as reader: - data = await reader.read(-1) + data = await aread(filename=pathname, encoding="utf-8") m = json.loads(data) self._repo = networkx.node_link_graph(m) diff --git a/metagpt/utils/graph_repository.py b/metagpt/utils/graph_repository.py index 37da3dee4..f9eb273f9 100644 --- a/metagpt/utils/graph_repository.py +++ b/metagpt/utils/graph_repository.py @@ -13,19 +13,24 @@ from typing import List from pydantic import BaseModel -from 
metagpt.repo_parser import ClassInfo, RepoFileInfo +from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo from metagpt.utils.common import concat_namespace class GraphKeyword: IS = "is" + OF = "Of" + ON = "On" CLASS = "class" FUNCTION = "function" + HAS_FUNCTION = "has_function" SOURCE_CODE = "source_code" NULL = "" GLOBAL_VARIABLE = "global_variable" CLASS_FUNCTION = "class_function" CLASS_PROPERTY = "class_property" + HAS_CLASS_FUNCTION = "has_class_function" + HAS_CLASS_PROPERTY = "has_class_property" HAS_CLASS = "has_class" HAS_PAGE_INFO = "has_page_info" HAS_CLASS_VIEW = "has_class_view" @@ -73,11 +78,13 @@ class GraphRepository(ABC): await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=file_type) for c in file_info.classes: class_name = c.get("name", "") + # file -> class await graph_db.insert( subject=file_info.file, predicate=GraphKeyword.HAS_CLASS, object_=concat_namespace(file_info.file, class_name), ) + # class detail await graph_db.insert( subject=concat_namespace(file_info.file, class_name), predicate=GraphKeyword.IS, @@ -85,12 +92,22 @@ class GraphRepository(ABC): ) methods = c.get("methods", []) for fn in methods: + await graph_db.insert( + subject=concat_namespace(file_info.file, class_name), + predicate=GraphKeyword.HAS_CLASS_FUNCTION, + object_=concat_namespace(file_info.file, class_name, fn), + ) await graph_db.insert( subject=concat_namespace(file_info.file, class_name, fn), predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS_FUNCTION, ) for f in file_info.functions: + # file -> function + await graph_db.insert( + subject=file_info.file, predicate=GraphKeyword.HAS_FUNCTION, object_=concat_namespace(file_info.file, f) + ) + # function detail await graph_db.insert( subject=concat_namespace(file_info.file, f), predicate=GraphKeyword.IS, object_=GraphKeyword.FUNCTION ) @@ -105,13 +122,13 @@ class GraphRepository(ABC): await graph_db.insert( subject=concat_namespace(file_info.file, 
*code_block.tokens), predicate=GraphKeyword.HAS_PAGE_INFO, - object_=code_block.json(ensure_ascii=False), + object_=code_block.model_dump_json(), ) for k, v in code_block.properties.items(): await graph_db.insert( subject=concat_namespace(file_info.file, k, v), predicate=GraphKeyword.HAS_PAGE_INFO, - object_=code_block.json(ensure_ascii=False), + object_=code_block.model_dump_json(), ) @staticmethod @@ -129,6 +146,13 @@ class GraphRepository(ABC): object_=GraphKeyword.CLASS, ) for vn, vt in c.attributes.items(): + # class -> property + await graph_db.insert( + subject=c.package, + predicate=GraphKeyword.HAS_CLASS_PROPERTY, + object_=concat_namespace(c.package, vn), + ) + # property detail await graph_db.insert( subject=concat_namespace(c.package, vn), predicate=GraphKeyword.IS, @@ -138,6 +162,13 @@ class GraphRepository(ABC): subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt ) for fn, desc in c.methods.items(): + # class -> function + await graph_db.insert( + subject=c.package, + predicate=GraphKeyword.HAS_CLASS_FUNCTION, + object_=concat_namespace(c.package, fn), + ) + # function detail await graph_db.insert( subject=concat_namespace(c.package, fn), predicate=GraphKeyword.IS, @@ -148,3 +179,19 @@ class GraphRepository(ABC): predicate=GraphKeyword.HAS_ARGS_DESC, object_=desc, ) + + @staticmethod + async def update_graph_db_with_class_relationship_views( + graph_db: "GraphRepository", relationship_views: List[ClassRelationship] + ): + for r in relationship_views: + await graph_db.insert( + subject=r.src, predicate=GraphKeyword.IS + r.relationship + GraphKeyword.OF, object_=r.dest + ) + if not r.label: + continue + await graph_db.insert( + subject=r.src, + predicate=GraphKeyword.IS + r.relationship + GraphKeyword.ON, + object_=concat_namespace(r.dest, r.label), + ) diff --git a/tests/metagpt/actions/test_rebuild_class_view.py b/tests/metagpt/actions/test_rebuild_class_view.py index 955c6ae3b..941a32a3d 100644 --- 
a/tests/metagpt/actions/test_rebuild_class_view.py +++ b/tests/metagpt/actions/test_rebuild_class_view.py @@ -16,7 +16,9 @@ from metagpt.llm import LLM @pytest.mark.asyncio async def test_rebuild(): - action = RebuildClassView(name="RedBean", context=Path(__file__).parent.parent, llm=LLM()) + action = RebuildClassView( + name="RedBean", context=str(Path(__file__).parent.parent.parent.parent / "metagpt"), llm=LLM() + ) await action.run()