mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-29 19:06:23 +02:00
feat: + class view
feat: +retry
This commit is contained in:
parent
4a5ae1b4a4
commit
78af904f5e
6 changed files with 182 additions and 77 deletions
|
|
@ -8,6 +8,7 @@
|
|||
"""
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import aiofiles
|
||||
|
||||
|
|
@ -29,25 +30,27 @@ from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
|
|||
|
||||
|
||||
class RebuildClassView(Action):
|
||||
graph_db: Optional[GraphRepository] = None
|
||||
|
||||
async def run(self, with_messages=None, format=config.prompt_schema):
|
||||
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
|
||||
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
|
||||
self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
|
||||
repo_parser = RepoParser(base_directory=Path(self.i_context))
|
||||
# use pylint
|
||||
class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.i_context))
|
||||
await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
|
||||
await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views)
|
||||
await GraphRepository.update_graph_db_with_class_views(self.graph_db, class_views)
|
||||
await GraphRepository.update_graph_db_with_class_relationship_views(self.graph_db, relationship_views)
|
||||
# use ast
|
||||
direction, diff_path = self._diff_path(path_root=Path(self.i_context).resolve(), package_root=package_root)
|
||||
symbols = repo_parser.generate_symbols()
|
||||
for file_info in symbols:
|
||||
# Align to the same root directory in accordance with `class_views`.
|
||||
file_info.file = self._align_root(file_info.file, direction, diff_path)
|
||||
await GraphRepository.update_graph_db_with_file_info(graph_db, file_info)
|
||||
await self._create_mermaid_class_views(graph_db=graph_db)
|
||||
await graph_db.save()
|
||||
await GraphRepository.update_graph_db_with_file_info(self.graph_db, file_info)
|
||||
await self._create_mermaid_class_views()
|
||||
await self.graph_db.save()
|
||||
|
||||
async def _create_mermaid_class_views(self, graph_db):
|
||||
async def _create_mermaid_class_views(self):
|
||||
path = Path(self.context.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
pathname = path / self.context.git_repo.workdir.name
|
||||
|
|
@ -56,47 +59,52 @@ class RebuildClassView(Action):
|
|||
logger.debug(content)
|
||||
await writer.write(content)
|
||||
# class names
|
||||
rows = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
|
||||
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
|
||||
class_distinct = set()
|
||||
relationship_distinct = set()
|
||||
for r in rows:
|
||||
await RebuildClassView._create_mermaid_class(r.subject, graph_db, writer, class_distinct)
|
||||
content = await self._create_mermaid_class(r.subject)
|
||||
if content:
|
||||
await writer.write(content)
|
||||
class_distinct.add(r.subject)
|
||||
for r in rows:
|
||||
await RebuildClassView._create_mermaid_relationship(r.subject, graph_db, writer, relationship_distinct)
|
||||
content, distinct = await self._create_mermaid_relationship(r.subject)
|
||||
if content:
|
||||
logger.debug(content)
|
||||
await writer.write(content)
|
||||
relationship_distinct += distinct
|
||||
logger.info(f"classes: {len(class_distinct)}, relationship: {len(relationship_distinct)}")
|
||||
|
||||
@staticmethod
|
||||
async def _create_mermaid_class(ns_class_name, graph_db, file_writer, distinct):
|
||||
async def _create_mermaid_class(self, ns_class_name) -> str:
|
||||
fields = split_namespace(ns_class_name)
|
||||
if len(fields) > 2:
|
||||
# Ignore sub-class
|
||||
return
|
||||
return ""
|
||||
|
||||
class_view = ClassView(name=fields[1])
|
||||
rows = await graph_db.select(subject=ns_class_name)
|
||||
rows = await self.graph_db.select(subject=ns_class_name)
|
||||
for r in rows:
|
||||
name = split_namespace(r.object_)[-1]
|
||||
name, visibility, abstraction = RebuildClassView._parse_name(name=name, language="python")
|
||||
if r.predicate == GraphKeyword.HAS_CLASS_PROPERTY:
|
||||
var_type = await RebuildClassView._parse_variable_type(r.object_, graph_db)
|
||||
var_type = await self._parse_variable_type(r.object_)
|
||||
attribute = ClassAttribute(
|
||||
name=name, visibility=visibility, abstraction=bool(abstraction), value_type=var_type
|
||||
)
|
||||
class_view.attributes.append(attribute)
|
||||
elif r.predicate == GraphKeyword.HAS_CLASS_FUNCTION:
|
||||
method = ClassMethod(name=name, visibility=visibility, abstraction=bool(abstraction))
|
||||
await RebuildClassView._parse_function_args(method, r.object_, graph_db)
|
||||
await self._parse_function_args(method, r.object_)
|
||||
class_view.methods.append(method)
|
||||
|
||||
# update graph db
|
||||
await graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
|
||||
await self.graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
|
||||
|
||||
content = class_view.get_mermaid(align=1)
|
||||
logger.debug(content)
|
||||
await file_writer.write(content)
|
||||
distinct.add(ns_class_name)
|
||||
return content
|
||||
|
||||
@staticmethod
|
||||
async def _create_mermaid_relationship(ns_class_name, graph_db, file_writer, distinct):
|
||||
async def _create_mermaid_relationship(self, ns_class_name):
|
||||
s_fields = split_namespace(ns_class_name)
|
||||
if len(s_fields) > 2:
|
||||
# Ignore sub-class
|
||||
|
|
@ -109,8 +117,9 @@ class RebuildClassView(Action):
|
|||
AGGREGATION: " o-- ",
|
||||
}
|
||||
content = ""
|
||||
distinct = set()
|
||||
for p, v in predicates.items():
|
||||
rows = await graph_db.select(subject=ns_class_name, predicate=p)
|
||||
rows = await self.graph_db.select(subject=ns_class_name, predicate=p)
|
||||
for r in rows:
|
||||
o_fields = split_namespace(r.object_)
|
||||
if len(o_fields) > 2:
|
||||
|
|
@ -121,13 +130,11 @@ class RebuildClassView(Action):
|
|||
distinct.add(link)
|
||||
content += f"\t{link}\n"
|
||||
|
||||
if content:
|
||||
logger.debug(content)
|
||||
await file_writer.write(content)
|
||||
return content, distinct
|
||||
|
||||
@staticmethod
|
||||
def _parse_name(name: str, language="python"):
|
||||
pattern = re.compile(r"<I>(.*?)<\/I>")
|
||||
pattern = re.compile(r"<I>(.*?)</I>")
|
||||
result = re.search(pattern, name)
|
||||
|
||||
abstraction = ""
|
||||
|
|
@ -142,9 +149,8 @@ class RebuildClassView(Action):
|
|||
visibility = "+"
|
||||
return name, visibility, abstraction
|
||||
|
||||
@staticmethod
|
||||
async def _parse_variable_type(ns_name, graph_db) -> str:
|
||||
rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_TYPE_DESC)
|
||||
async def _parse_variable_type(self, ns_name) -> str:
|
||||
rows = await self.graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_TYPE_DESC)
|
||||
if not rows:
|
||||
return ""
|
||||
vals = rows[0].object_.replace("'", "").split(":")
|
||||
|
|
@ -153,9 +159,8 @@ class RebuildClassView(Action):
|
|||
val = vals[-1].strip()
|
||||
return "" if val == "NoneType" else val + " "
|
||||
|
||||
@staticmethod
|
||||
async def _parse_function_args(method: ClassMethod, ns_name: str, graph_db: GraphRepository):
|
||||
rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_ARGS_DESC)
|
||||
async def _parse_function_args(self, method: ClassMethod, ns_name: str):
|
||||
rows = await self.graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_ARGS_DESC)
|
||||
if not rows:
|
||||
return
|
||||
info = rows[0].object_.replace("'", "")
|
||||
|
|
|
|||
|
|
@ -8,30 +8,36 @@
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import GRAPH_REPO_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import aread, list_files
|
||||
from metagpt.schema import ClassView
|
||||
from metagpt.utils.common import aread, general_after_log, list_files, split_namespace
|
||||
from metagpt.utils.di_graph_repository import DiGraphRepository
|
||||
from metagpt.utils.graph_repository import GraphKeyword
|
||||
from metagpt.utils.graph_repository import SPO, GraphKeyword, GraphRepository
|
||||
|
||||
|
||||
class RebuildSequenceView(Action):
|
||||
graph_db: Optional[GraphRepository] = None
|
||||
|
||||
async def run(self, with_messages=None, format=config.prompt_schema):
|
||||
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
|
||||
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
|
||||
entries = await RebuildSequenceView._search_main_entry(graph_db)
|
||||
self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
|
||||
entries = await self._search_main_entry()
|
||||
for entry in entries:
|
||||
await self._rebuild_sequence_view(entry, graph_db)
|
||||
await graph_db.save()
|
||||
await self._rebuild_sequence_view(entry)
|
||||
await self.graph_db.save()
|
||||
|
||||
@staticmethod
|
||||
async def _search_main_entry(graph_db) -> List:
|
||||
rows = await graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
|
||||
async def _search_main_entry(self) -> List:
|
||||
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
|
||||
tag = "__name__:__main__"
|
||||
entries = []
|
||||
for r in rows:
|
||||
|
|
@ -39,18 +45,97 @@ class RebuildSequenceView(Action):
|
|||
entries.append(r)
|
||||
return entries
|
||||
|
||||
async def _rebuild_sequence_view(self, entry, graph_db):
|
||||
async def _rebuild_sequence_view(self, entry):
|
||||
filename = entry.subject.split(":", 1)[0]
|
||||
src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
|
||||
if not src_filename:
|
||||
return
|
||||
content = await aread(filename=src_filename, encoding="utf-8")
|
||||
content = f"```python\n{content}\n```\n\n---\nTranslate the code above into Mermaid Sequence Diagram."
|
||||
data = await self.llm.aask(
|
||||
sequence_view = await self.llm.aask(
|
||||
msg=content, system_msgs=["You are a python code to Mermaid Sequence Diagram translator in function detail"]
|
||||
)
|
||||
await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=data)
|
||||
logger.info(data)
|
||||
await self.graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view)
|
||||
logger.info(sequence_view)
|
||||
|
||||
merged_class_views = set()
|
||||
while True:
|
||||
participants = RebuildSequenceView.parse_participant(sequence_view)
|
||||
class_views = await self._get_class_views(participants)
|
||||
diff = set()
|
||||
for cv in class_views:
|
||||
if cv.subject in merged_class_views:
|
||||
continue
|
||||
class_functionality, class_view = await self._parse_class_functionality(cv)
|
||||
sequence_view = await self._merge_sequence_view(sequence_view, class_view, class_functionality)
|
||||
diff.add(cv.subject)
|
||||
merged_class_views.add(cv.subject)
|
||||
|
||||
await self.graph_db.delete(subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
|
||||
await self.graph_db.insert(
|
||||
subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
|
||||
)
|
||||
logger.info(sequence_view)
|
||||
if diff:
|
||||
continue
|
||||
break
|
||||
|
||||
@retry(
    wait=wait_random_exponential(min=1, max=20),
    stop=stop_after_attempt(6),
    after=general_after_log(logger),
)
async def _parse_class_functionality(self, spo):
    """Ask the LLM to describe a class and store the description in the graph.

    Args:
        spo: A graph row whose ``object_`` is a JSON-serialized ``ClassView``
            and whose ``subject`` identifies the class.

    Returns:
        A ``(functionality, class_view_content)`` tuple: the JSON string that
        was stored under ``HAS_CLASS_DESC`` (``description`` and ``goal``
        only), and the class view rendered as a fenced Mermaid block.

    Any exception (LLM failure, unparsable reply) triggers another attempt
    through the ``@retry`` decorator, up to six attempts in total.
    """

    class _JsonCodeBlock(BaseModel):
        description: str
        goal: str
        reason: Optional[str] = None

    class_view = ClassView.model_validate_json(spo.object_)
    class_view_content = f"```mermaid\n{class_view.get_mermaid(align=0)}\n```"
    rsp = await self.llm.aask(
        msg=f"## Class View\n{class_view_content}",
        system_msgs=[
            "You are a tool capable of translating class views into a textual description of their functionalities and goals.",
            'Please return a Markdown JSON format with a "description" key containing a concise textual description of the class functionalities, a "goal" key containing a concise textual description of the problem the class aims to solve, and a "reason" key explaining why.',
        ],
    )

    # Take the outermost `{...}` instead of stripping an exact "```json\n"
    # prefix / "```" suffix: the reply may carry leading text, a differently
    # spelled fence, or a newline after the closing fence, and none of those
    # should cost a retry. Replies the exact-strip form could parse yield the
    # same JSON object here.
    start, end = rsp.find("{"), rsp.rfind("}")
    code_block = rsp[start : end + 1] if 0 <= start < end else rsp
    data = _JsonCodeBlock.model_validate_json(code_block)

    # "reason" is requested from the LLM but deliberately not persisted:
    # clearing it lets `exclude_none` drop the key from the stored JSON.
    data.reason = None
    functionality = data.model_dump_json(exclude_none=True)
    await self.graph_db.insert(subject=spo.subject, predicate=GraphKeyword.HAS_CLASS_DESC, object_=functionality)
    return functionality, class_view_content
|
||||
|
||||
@retry(
    wait=wait_random_exponential(min=1, max=20),
    stop=stop_after_attempt(6),
    after=general_after_log(logger),
)
async def _merge_sequence_view(self, sequence_view, class_view, class_functionality) -> str:
    """Ask the LLM to merge one class's information into a sequence diagram.

    Args:
        sequence_view: The current Mermaid sequence diagram text.
        class_view: The Mermaid class view text of the class to merge in.
        class_functionality: The textual/JSON description of that class.

    Returns:
        The ``sequence_diagram`` field of the LLM's JSON reply.

    Any exception (LLM failure, unparsable reply) triggers another attempt
    through the ``@retry`` decorator, up to six attempts in total.
    """

    class _JsonCodeBlock(BaseModel):
        sequence_diagram: str
        reason: str

    contents = [
        f"## Mermaid Class View\n{class_view}",
        f"## Class Description\n{class_functionality}",
        f"## Mermaid Sequence View\n{sequence_view}",
    ]
    msg = "\n---\n".join(contents)
    rsp = await self.llm.aask(
        msg=msg,
        system_msgs=[
            "You are a tool to merge Mermaid class view information into the Mermaid sequence view.",
            'Append as much information as possible from the "Mermaid Class View" and "Class Description" to the sequence diagram.',
            'Return a markdown JSON format with a "sequence_diagram" key containing the merged Mermaid sequence view, a "reason" key explaining what information have been merged and why.',
        ],
    )

    # Take the outermost `{...}` instead of stripping an exact "```json\n"
    # prefix / "```" suffix. Searching for braces (rather than for fences)
    # also stays correct when the diagram string itself contains backtick
    # fences, and tolerates leading text or a trailing newline in the reply.
    start, end = rsp.find("{"), rsp.rfind("}")
    code_block = rsp[start : end + 1] if 0 <= start < end else rsp
    data = _JsonCodeBlock.model_validate_json(code_block)
    return data.sequence_diagram
|
||||
|
||||
@staticmethod
|
||||
def _get_full_filename(root: str | Path, pathname: str | Path) -> Path | None:
|
||||
|
|
@ -60,3 +145,26 @@ class RebuildSequenceView(Action):
|
|||
if str(i).endswith(postfix):
|
||||
return i
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_participant(mermaid_sequence_diagram: str) -> List[str]:
|
||||
pattern = r"participant (\w+)"
|
||||
matches = re.findall(pattern, mermaid_sequence_diagram)
|
||||
return matches
|
||||
|
||||
async def _get_class_views(self, class_names) -> List[SPO]:
    """Collect the stored class views of the classes named in `class_names`.

    Args:
        class_names: Bare class names to look for (tested with ``in``).

    Returns:
        All ``HAS_CLASS_VIEW`` rows of every graph subject that is marked as
        a class and whose last namespace field is in `class_names`.
    """
    rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)

    # Index the last namespace field instead of unpacking into two names:
    # unpacking raises ValueError as soon as a subject does not split into
    # exactly two fields (e.g. it contains no ":"). dict.fromkeys keeps each
    # matching subject once, in first-seen order.
    matched_subjects = dict.fromkeys(r.subject for r in rows if split_namespace(r.subject)[-1] in class_names)

    class_views = []
    for ns_name in matched_subjects:
        class_views += await self.graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
    return class_views
|
||||
|
|
|
|||
|
|
@ -1,16 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4
|
||||
@Author : mashenquan
|
||||
@File : rebuild_sequence_view_an.py
|
||||
"""
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.utils.mermaid import MMC2
|
||||
|
||||
CODE_2_MERMAID_SEQUENCE_DIAGRAM = ActionNode(
|
||||
key="Program call flow",
|
||||
expected_type=str,
|
||||
instruction='Translate the "context" content into "format example" format.',
|
||||
example=MMC2,
|
||||
)
|
||||
|
|
@ -408,7 +408,7 @@ def concat_namespace(*args) -> str:
|
|||
|
||||
|
||||
def split_namespace(ns_class_name: str) -> List[str]:
|
||||
return ns_class_name.split(":")
|
||||
return ns_class_name.split(":", maxsplit=1)
|
||||
|
||||
|
||||
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
|
||||
|
|
@ -567,3 +567,8 @@ def list_files(root: str | Path) -> List[Path]:
|
|||
except Exception as e:
|
||||
logger.error(f"Error: {e}")
|
||||
return files
|
||||
|
||||
|
||||
def parse_json_code_block(markdown_text: str) -> List[str]:
    """Extract the contents of every ```json fenced block in `markdown_text`.

    Args:
        markdown_text: Markdown text that may contain ```json ... ``` blocks.

    Returns:
        The body of each block, stripped of surrounding whitespace, in order
        of appearance; an empty list if there is no such block.
    """
    # `\b` requires the info string to be exactly "json": without it a fence
    # such as ```jsonl or ```json5 also matched and the rest of its tag
    # ("l", "5") was returned as if it were part of the JSON payload.
    json_blocks = re.findall(r"```json\b(.*?)```", markdown_text, re.DOTALL)
    return [v.strip() for v in json_blocks]
|
||||
|
|
|
|||
|
|
@ -26,12 +26,6 @@ class DiGraphRepository(GraphRepository):
|
|||
async def insert(self, subject: str, predicate: str, object_: str):
|
||||
self._repo.add_edge(subject, object_, predicate=predicate)
|
||||
|
||||
async def upsert(self, subject: str, predicate: str, object_: str):
|
||||
pass
|
||||
|
||||
async def update(self, subject: str, predicate: str, object_: str):
|
||||
pass
|
||||
|
||||
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
|
||||
result = []
|
||||
for s, o, p in self._repo.edges(data="predicate"):
|
||||
|
|
@ -44,6 +38,14 @@ class DiGraphRepository(GraphRepository):
|
|||
result.append(SPO(subject=s, predicate=p, object_=o))
|
||||
return result
|
||||
|
||||
async def delete(self, subject: str = None, predicate: str = None, object_: str = None) -> int:
    """Remove every edge that `select` returns for the given filters.

    The three filters are forwarded unchanged to `select`.

    Returns:
        The number of rows that were selected and removed (0 if none).
    """
    matched = await self.select(subject=subject, predicate=predicate, object_=object_)
    # NOTE(review): removal is keyed on (subject, object_) only; presumably
    # the underlying graph holds a single edge per node pair - confirm.
    for spo in matched:
        self._repo.remove_edge(spo.subject, spo.object_)
    return len(matched)
|
||||
|
||||
def json(self) -> str:
|
||||
m = networkx.node_link_data(self._repo)
|
||||
data = json.dumps(m)
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ class GraphKeyword:
|
|||
HAS_CLASS_FUNCTION = "has_class_function"
|
||||
HAS_CLASS_PROPERTY = "has_class_property"
|
||||
HAS_CLASS = "has_class"
|
||||
HAS_CLASS_DESC = "has_class_desc"
|
||||
HAS_PAGE_INFO = "has_page_info"
|
||||
HAS_CLASS_VIEW = "has_class_view"
|
||||
HAS_SEQUENCE_VIEW = "has_sequence_view"
|
||||
|
|
@ -55,18 +56,18 @@ class GraphRepository(ABC):
|
|||
async def insert(self, subject: str, predicate: str, object_: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def upsert(self, subject: str, predicate: str, object_: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def update(self, subject: str, predicate: str, object_: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, subject: str = None, predicate: str = None, object_: str = None) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def save(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._repo_name
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue