diff --git a/metagpt/actions/rebuild_class_view.py b/metagpt/actions/rebuild_class_view.py index 2140ad874..2e27d37fc 100644 --- a/metagpt/actions/rebuild_class_view.py +++ b/metagpt/actions/rebuild_class_view.py @@ -8,6 +8,7 @@ """ import re from pathlib import Path +from typing import Optional import aiofiles @@ -29,25 +30,27 @@ from metagpt.utils.graph_repository import GraphKeyword, GraphRepository class RebuildClassView(Action): + graph_db: Optional[GraphRepository] = None + async def run(self, with_messages=None, format=config.prompt_schema): graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name - graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json"))) + self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json"))) repo_parser = RepoParser(base_directory=Path(self.i_context)) # use pylint class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.i_context)) - await GraphRepository.update_graph_db_with_class_views(graph_db, class_views) - await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views) + await GraphRepository.update_graph_db_with_class_views(self.graph_db, class_views) + await GraphRepository.update_graph_db_with_class_relationship_views(self.graph_db, relationship_views) # use ast direction, diff_path = self._diff_path(path_root=Path(self.i_context).resolve(), package_root=package_root) symbols = repo_parser.generate_symbols() for file_info in symbols: # Align to the same root directory in accordance with `class_views`. file_info.file = self._align_root(file_info.file, direction, diff_path) - await GraphRepository.update_graph_db_with_file_info(graph_db, file_info) - await self._create_mermaid_class_views(graph_db=graph_db) - await graph_db.save() + await GraphRepository.update_graph_db_with_file_info(self.graph_db, file_info) + await self._create_mermaid_class_views() + await self.graph_db.save() - async def _create_mermaid_class_views(self, graph_db): + async def _create_mermaid_class_views(self): path = Path(self.context.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO path.mkdir(parents=True, exist_ok=True) pathname = path / self.context.git_repo.workdir.name @@ -56,47 +59,52 @@ class RebuildClassView(Action): logger.debug(content) await writer.write(content) # class names - rows = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS) + rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS) class_distinct = set() relationship_distinct = set() for r in rows: - await RebuildClassView._create_mermaid_class(r.subject, graph_db, writer, class_distinct) + content = await self._create_mermaid_class(r.subject) + if content: + await writer.write(content) + class_distinct.add(r.subject) for r in rows: - await RebuildClassView._create_mermaid_relationship(r.subject, graph_db, writer, relationship_distinct) + content, distinct = await self._create_mermaid_relationship(r.subject) + if content: + logger.debug(content) + await writer.write(content) + relationship_distinct += distinct + logger.info(f"classes: {len(class_distinct)}, relationship: {len(relationship_distinct)}") - @staticmethod - async def _create_mermaid_class(ns_class_name, graph_db, file_writer, distinct): + async def _create_mermaid_class(self, ns_class_name) -> str: fields = split_namespace(ns_class_name) if len(fields) > 2: # Ignore sub-class - return + return "" class_view = ClassView(name=fields[1]) - rows = await graph_db.select(subject=ns_class_name) + rows = await self.graph_db.select(subject=ns_class_name) for r in rows: name = split_namespace(r.object_)[-1] name, visibility, abstraction = RebuildClassView._parse_name(name=name, language="python") if r.predicate == GraphKeyword.HAS_CLASS_PROPERTY: - var_type = await RebuildClassView._parse_variable_type(r.object_, graph_db) + var_type = await self._parse_variable_type(r.object_) attribute = ClassAttribute( name=name, visibility=visibility, abstraction=bool(abstraction), value_type=var_type ) class_view.attributes.append(attribute) elif r.predicate == GraphKeyword.HAS_CLASS_FUNCTION: method = ClassMethod(name=name, visibility=visibility, abstraction=bool(abstraction)) - await RebuildClassView._parse_function_args(method, r.object_, graph_db) + await self._parse_function_args(method, r.object_) class_view.methods.append(method) # update graph db - await graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json()) + await self.graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json()) content = class_view.get_mermaid(align=1) logger.debug(content) - await file_writer.write(content) - distinct.add(ns_class_name) + return content - @staticmethod - async def _create_mermaid_relationship(ns_class_name, graph_db, file_writer, distinct): + async def _create_mermaid_relationship(self, ns_class_name): s_fields = split_namespace(ns_class_name) if len(s_fields) > 2: # Ignore sub-class @@ -109,8 +117,9 @@ class RebuildClassView(Action): AGGREGATION: " o-- ", } content = "" + distinct = set() for p, v in predicates.items(): - rows = await graph_db.select(subject=ns_class_name, predicate=p) + rows = await self.graph_db.select(subject=ns_class_name, predicate=p) for r in rows: o_fields = split_namespace(r.object_) if len(o_fields) > 2: @@ -121,13 +130,11 @@ class RebuildClassView(Action): distinct.add(link) content += f"\t{link}\n" - if content: - logger.debug(content) - await file_writer.write(content) + return content, distinct @staticmethod def _parse_name(name: str, language="python"): - pattern = re.compile(r"(.*?)<\/I>") + pattern = re.compile(r"(.*?)") result = re.search(pattern, name) abstraction = "" @@ -142,9 +149,8 @@ class RebuildClassView(Action): visibility = "+" return name, visibility, abstraction - @staticmethod - async def _parse_variable_type(ns_name, graph_db) -> str: - rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_TYPE_DESC) + async def _parse_variable_type(self, ns_name) -> str: + rows = await self.graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_TYPE_DESC) if not rows: return "" vals = rows[0].object_.replace("'", "").split(":") @@ -153,9 +159,8 @@ class RebuildClassView(Action): val = vals[-1].strip() return "" if val == "NoneType" else val + " " - @staticmethod - async def _parse_function_args(method: ClassMethod, ns_name: str, graph_db: GraphRepository): - rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_ARGS_DESC) + async def _parse_function_args(self, method: ClassMethod, ns_name: str): + rows = await self.graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_ARGS_DESC) if not rows: return info = rows[0].object_.replace("'", "") diff --git a/metagpt/actions/rebuild_sequence_view.py b/metagpt/actions/rebuild_sequence_view.py index 777dde8ce..c0afd239f 100644 --- a/metagpt/actions/rebuild_sequence_view.py +++ b/metagpt/actions/rebuild_sequence_view.py @@ -8,30 +8,36 @@ """ from __future__ import annotations +import re from pathlib import Path -from typing import List +from typing import List, Optional + +from pydantic import BaseModel +from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions import Action from metagpt.config2 import config from metagpt.const import GRAPH_REPO_FILE_REPO from metagpt.logs import logger -from metagpt.utils.common import aread, list_files +from metagpt.schema import ClassView +from metagpt.utils.common import aread, general_after_log, list_files, split_namespace from metagpt.utils.di_graph_repository import DiGraphRepository -from metagpt.utils.graph_repository import GraphKeyword +from metagpt.utils.graph_repository import SPO, GraphKeyword, GraphRepository class RebuildSequenceView(Action): + graph_db: Optional[GraphRepository] = None + async def run(self, with_messages=None, format=config.prompt_schema): graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name - graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json"))) - entries = await RebuildSequenceView._search_main_entry(graph_db) + self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json"))) + entries = await self._search_main_entry() for entry in entries: - await self._rebuild_sequence_view(entry, graph_db) - await graph_db.save() + await self._rebuild_sequence_view(entry) + await self.graph_db.save() - @staticmethod - async def _search_main_entry(graph_db) -> List: - rows = await graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO) + async def _search_main_entry(self) -> List: + rows = await self.graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO) tag = "__name__:__main__" entries = [] for r in rows: @@ -39,18 +45,97 @@ class RebuildSequenceView(Action): entries.append(r) return entries - async def _rebuild_sequence_view(self, entry, graph_db): + async def _rebuild_sequence_view(self, entry): filename = entry.subject.split(":", 1)[0] src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename) if not src_filename: return content = await aread(filename=src_filename, encoding="utf-8") content = f"```python\n{content}\n```\n\n---\nTranslate the code above into Mermaid Sequence Diagram." - data = await self.llm.aask( + sequence_view = await self.llm.aask( msg=content, system_msgs=["You are a python code to Mermaid Sequence Diagram translator in function detail"] ) - await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=data) - logger.info(data) + await self.graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view) + logger.info(sequence_view) + + merged_class_views = set() + while True: + participants = RebuildSequenceView.parse_participant(sequence_view) + class_views = await self._get_class_views(participants) + diff = set() + for cv in class_views: + if cv.subject in merged_class_views: + continue + class_functionality, class_view = await self._parse_class_functionality(cv) + sequence_view = await self._merge_sequence_view(sequence_view, class_view, class_functionality) + diff.add(cv.subject) + merged_class_views.add(cv.subject) + + await self.graph_db.delete(subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW) + await self.graph_db.insert( + subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view + ) + logger.info(sequence_view) + if diff: + continue + break + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _parse_class_functionality(self, spo): + class_view = ClassView.model_validate_json(spo.object_) + class_view_content = f"```mermaid\n{class_view.get_mermaid(align=0)}\n```" + rsp = await self.llm.aask( + msg=f"## Class View\n{class_view_content}", + system_msgs=[ + "You are a tool capable of translating class views into a textual description of their functionalities and goals.", + 'Please return a Markdown JSON format with a "description" key containing a concise textual description of the class functionalities, a "goal" key containing a concise textual description of the problem the class aims to solve, and a "reason" key explaining why.', + ], + ) + + class _JsonCodeBlock(BaseModel): + description: str + goal: str + reason: Optional[str] = None + + code_block = rsp.removeprefix("```json\n").removesuffix("```") + data = _JsonCodeBlock.model_validate_json(code_block) + data.reason = None + functionality = data.model_dump_json(exclude_none=True) + await self.graph_db.insert(subject=spo.subject, predicate=GraphKeyword.HAS_CLASS_DESC, object_=functionality) + return functionality, class_view_content + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _merge_sequence_view(self, sequence_view, class_view, class_functionality) -> str: + contents = [ + f"## Mermaid Class View\n{class_view}", + f"## Class Description\n{class_functionality}", + f"## Mermaid Sequence View\n{sequence_view}", + ] + msg = "\n---\n".join(contents) + rsp = await self.llm.aask( + msg=msg, + system_msgs=[ + "You are a tool to merge Mermaid class view information into the Mermaid sequence view.", + 'Append as much information as possible from the "Mermaid Class View" and "Class Description" to the sequence diagram.', + 'Return a markdown JSON format with a "sequence_diagram" key containing the merged Mermaid sequence view, a "reason" key explaining what information have been merged and why.', + ], + ) + + class _JsonCodeBlock(BaseModel): + sequence_diagram: str + reason: str + + code_block = rsp.removeprefix("```json\n").removesuffix("```") + data = _JsonCodeBlock.model_validate_json(code_block) + return data.sequence_diagram @staticmethod def _get_full_filename(root: str | Path, pathname: str | Path) -> Path | None: @@ -60,3 +145,26 @@ class RebuildSequenceView(Action): if str(i).endswith(postfix): return i return None + + @staticmethod + def parse_participant(mermaid_sequence_diagram: str) -> List[str]: + pattern = r"participant (\w+)" + matches = re.findall(pattern, mermaid_sequence_diagram) + return matches + + async def _get_class_views(self, class_names) -> List[SPO]: + rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS) + + ns_class_names = {} + for r in rows: + ns, class_name = split_namespace(r.subject) + if class_name in class_names: + ns_class_names[r.subject] = class_name + + class_views = [] + for ns_name in ns_class_names.keys(): + views = await self.graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_CLASS_VIEW) + if not views: + continue + class_views += views + return class_views diff --git a/metagpt/actions/rebuild_sequence_view_an.py b/metagpt/actions/rebuild_sequence_view_an.py deleted file mode 100644 index f16431510..000000000 --- a/metagpt/actions/rebuild_sequence_view_an.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2024/1/4 -@Author : mashenquan -@File : rebuild_sequence_view_an.py -""" -from metagpt.actions.action_node import ActionNode -from metagpt.utils.mermaid import MMC2 - -CODE_2_MERMAID_SEQUENCE_DIAGRAM = ActionNode( - key="Program call flow", - expected_type=str, - instruction='Translate the "context" content into "format example" format.', - example=MMC2, -) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index e09d49d84..879fe2637 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -408,7 +408,7 @@ def concat_namespace(*args) -> str: def split_namespace(ns_class_name: str) -> List[str]: - return ns_class_name.split(":") + return ns_class_name.split(":", maxsplit=1) def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]: @@ -567,3 +567,8 @@ def list_files(root: str | Path) -> List[Path]: except Exception as e: logger.error(f"Error: {e}") return files + + +def parse_json_code_block(markdown_text: str) -> List[str]: + json_blocks = re.findall(r"```json(.*?)```", markdown_text, re.DOTALL) + return [v.strip() for v in json_blocks] diff --git a/metagpt/utils/di_graph_repository.py b/metagpt/utils/di_graph_repository.py index 8bb5f9bb3..b73946cd2 100644 --- a/metagpt/utils/di_graph_repository.py +++ b/metagpt/utils/di_graph_repository.py @@ -26,12 +26,6 @@ class DiGraphRepository(GraphRepository): async def insert(self, subject: str, predicate: str, object_: str): self._repo.add_edge(subject, object_, predicate=predicate) - async def upsert(self, subject: str, predicate: str, object_: str): - pass - - async def update(self, subject: str, predicate: str, object_: str): - pass - async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]: result = [] for s, o, p in self._repo.edges(data="predicate"): @@ -44,6 +38,14 @@ class DiGraphRepository(GraphRepository): result.append(SPO(subject=s, predicate=p, object_=o)) return result + async def delete(self, subject: str = None, predicate: str = None, object_: str = None) -> int: + rows = await self.select(subject=subject, predicate=predicate, object_=object_) + if not rows: + return 0 + for r in rows: + self._repo.remove_edge(r.subject, r.object_) + return len(rows) + def json(self) -> str: m = networkx.node_link_data(self._repo) data = json.dumps(m) diff --git a/metagpt/utils/graph_repository.py b/metagpt/utils/graph_repository.py index 1a6f29a6b..16264cad2 100644 --- a/metagpt/utils/graph_repository.py +++ b/metagpt/utils/graph_repository.py @@ -33,6 +33,7 @@ class GraphKeyword: HAS_CLASS_FUNCTION = "has_class_function" HAS_CLASS_PROPERTY = "has_class_property" HAS_CLASS = "has_class" + HAS_CLASS_DESC = "has_class_desc" HAS_PAGE_INFO = "has_page_info" HAS_CLASS_VIEW = "has_class_view" HAS_SEQUENCE_VIEW = "has_sequence_view" @@ -55,18 +56,18 @@ class GraphRepository(ABC): async def insert(self, subject: str, predicate: str, object_: str): pass - @abstractmethod - async def upsert(self, subject: str, predicate: str, object_: str): - pass - - @abstractmethod - async def update(self, subject: str, predicate: str, object_: str): - pass - @abstractmethod async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]: pass + @abstractmethod + async def delete(self, subject: str = None, predicate: str = None, object_: str = None) -> int: + pass + + @abstractmethod + async def save(self): + pass + @property def name(self) -> str: return self._repo_name