feat: + class view

feat: +retry
This commit is contained in:
莘权 马 2024-01-18 22:33:04 +08:00
parent 4a5ae1b4a4
commit 78af904f5e
6 changed files with 182 additions and 77 deletions

View file

@ -8,6 +8,7 @@
"""
import re
from pathlib import Path
from typing import Optional
import aiofiles
@ -29,25 +30,27 @@ from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class RebuildClassView(Action):
graph_db: Optional[GraphRepository] = None
async def run(self, with_messages=None, format=config.prompt_schema):
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
repo_parser = RepoParser(base_directory=Path(self.i_context))
# use pylint
class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.i_context))
await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views)
await GraphRepository.update_graph_db_with_class_views(self.graph_db, class_views)
await GraphRepository.update_graph_db_with_class_relationship_views(self.graph_db, relationship_views)
# use ast
direction, diff_path = self._diff_path(path_root=Path(self.i_context).resolve(), package_root=package_root)
symbols = repo_parser.generate_symbols()
for file_info in symbols:
# Align to the same root directory in accordance with `class_views`.
file_info.file = self._align_root(file_info.file, direction, diff_path)
await GraphRepository.update_graph_db_with_file_info(graph_db, file_info)
await self._create_mermaid_class_views(graph_db=graph_db)
await graph_db.save()
await GraphRepository.update_graph_db_with_file_info(self.graph_db, file_info)
await self._create_mermaid_class_views()
await self.graph_db.save()
async def _create_mermaid_class_views(self, graph_db):
async def _create_mermaid_class_views(self):
path = Path(self.context.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
path.mkdir(parents=True, exist_ok=True)
pathname = path / self.context.git_repo.workdir.name
@ -56,47 +59,52 @@ class RebuildClassView(Action):
logger.debug(content)
await writer.write(content)
# class names
rows = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
class_distinct = set()
relationship_distinct = set()
for r in rows:
await RebuildClassView._create_mermaid_class(r.subject, graph_db, writer, class_distinct)
content = await self._create_mermaid_class(r.subject)
if content:
await writer.write(content)
class_distinct.add(r.subject)
for r in rows:
await RebuildClassView._create_mermaid_relationship(r.subject, graph_db, writer, relationship_distinct)
content, distinct = await self._create_mermaid_relationship(r.subject)
if content:
logger.debug(content)
await writer.write(content)
relationship_distinct += distinct
logger.info(f"classes: {len(class_distinct)}, relationship: {len(relationship_distinct)}")
@staticmethod
async def _create_mermaid_class(ns_class_name, graph_db, file_writer, distinct):
async def _create_mermaid_class(self, ns_class_name) -> str:
fields = split_namespace(ns_class_name)
if len(fields) > 2:
# Ignore sub-class
return
return ""
class_view = ClassView(name=fields[1])
rows = await graph_db.select(subject=ns_class_name)
rows = await self.graph_db.select(subject=ns_class_name)
for r in rows:
name = split_namespace(r.object_)[-1]
name, visibility, abstraction = RebuildClassView._parse_name(name=name, language="python")
if r.predicate == GraphKeyword.HAS_CLASS_PROPERTY:
var_type = await RebuildClassView._parse_variable_type(r.object_, graph_db)
var_type = await self._parse_variable_type(r.object_)
attribute = ClassAttribute(
name=name, visibility=visibility, abstraction=bool(abstraction), value_type=var_type
)
class_view.attributes.append(attribute)
elif r.predicate == GraphKeyword.HAS_CLASS_FUNCTION:
method = ClassMethod(name=name, visibility=visibility, abstraction=bool(abstraction))
await RebuildClassView._parse_function_args(method, r.object_, graph_db)
await self._parse_function_args(method, r.object_)
class_view.methods.append(method)
# update graph db
await graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
await self.graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
content = class_view.get_mermaid(align=1)
logger.debug(content)
await file_writer.write(content)
distinct.add(ns_class_name)
return content
@staticmethod
async def _create_mermaid_relationship(ns_class_name, graph_db, file_writer, distinct):
async def _create_mermaid_relationship(self, ns_class_name):
s_fields = split_namespace(ns_class_name)
if len(s_fields) > 2:
# Ignore sub-class
@ -109,8 +117,9 @@ class RebuildClassView(Action):
AGGREGATION: " o-- ",
}
content = ""
distinct = set()
for p, v in predicates.items():
rows = await graph_db.select(subject=ns_class_name, predicate=p)
rows = await self.graph_db.select(subject=ns_class_name, predicate=p)
for r in rows:
o_fields = split_namespace(r.object_)
if len(o_fields) > 2:
@ -121,13 +130,11 @@ class RebuildClassView(Action):
distinct.add(link)
content += f"\t{link}\n"
if content:
logger.debug(content)
await file_writer.write(content)
return content, distinct
@staticmethod
def _parse_name(name: str, language="python"):
pattern = re.compile(r"<I>(.*?)<\/I>")
pattern = re.compile(r"<I>(.*?)</I>")
result = re.search(pattern, name)
abstraction = ""
@ -142,9 +149,8 @@ class RebuildClassView(Action):
visibility = "+"
return name, visibility, abstraction
@staticmethod
async def _parse_variable_type(ns_name, graph_db) -> str:
rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_TYPE_DESC)
async def _parse_variable_type(self, ns_name) -> str:
rows = await self.graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_TYPE_DESC)
if not rows:
return ""
vals = rows[0].object_.replace("'", "").split(":")
@ -153,9 +159,8 @@ class RebuildClassView(Action):
val = vals[-1].strip()
return "" if val == "NoneType" else val + " "
@staticmethod
async def _parse_function_args(method: ClassMethod, ns_name: str, graph_db: GraphRepository):
rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_ARGS_DESC)
async def _parse_function_args(self, method: ClassMethod, ns_name: str):
rows = await self.graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_ARGS_DESC)
if not rows:
return
info = rows[0].object_.replace("'", "")

View file

@ -8,30 +8,36 @@
"""
from __future__ import annotations
import re
from pathlib import Path
from typing import List
from typing import List, Optional
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions import Action
from metagpt.config2 import config
from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.logs import logger
from metagpt.utils.common import aread, list_files
from metagpt.schema import ClassView
from metagpt.utils.common import aread, general_after_log, list_files, split_namespace
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword
from metagpt.utils.graph_repository import SPO, GraphKeyword, GraphRepository
class RebuildSequenceView(Action):
graph_db: Optional[GraphRepository] = None
async def run(self, with_messages=None, format=config.prompt_schema):
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
entries = await RebuildSequenceView._search_main_entry(graph_db)
self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
entries = await self._search_main_entry()
for entry in entries:
await self._rebuild_sequence_view(entry, graph_db)
await graph_db.save()
await self._rebuild_sequence_view(entry)
await self.graph_db.save()
@staticmethod
async def _search_main_entry(graph_db) -> List:
rows = await graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
async def _search_main_entry(self) -> List:
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
tag = "__name__:__main__"
entries = []
for r in rows:
@ -39,18 +45,97 @@ class RebuildSequenceView(Action):
entries.append(r)
return entries
async def _rebuild_sequence_view(self, entry, graph_db):
async def _rebuild_sequence_view(self, entry):
filename = entry.subject.split(":", 1)[0]
src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
if not src_filename:
return
content = await aread(filename=src_filename, encoding="utf-8")
content = f"```python\n{content}\n```\n\n---\nTranslate the code above into Mermaid Sequence Diagram."
data = await self.llm.aask(
sequence_view = await self.llm.aask(
msg=content, system_msgs=["You are a python code to Mermaid Sequence Diagram translator in function detail"]
)
await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=data)
logger.info(data)
await self.graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view)
logger.info(sequence_view)
merged_class_views = set()
while True:
participants = RebuildSequenceView.parse_participant(sequence_view)
class_views = await self._get_class_views(participants)
diff = set()
for cv in class_views:
if cv.subject in merged_class_views:
continue
class_functionality, class_view = await self._parse_class_functionality(cv)
sequence_view = await self._merge_sequence_view(sequence_view, class_view, class_functionality)
diff.add(cv.subject)
merged_class_views.add(cv.subject)
await self.graph_db.delete(subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
await self.graph_db.insert(
subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
)
logger.info(sequence_view)
if diff:
continue
break
@retry(
wait=wait_random_exponential(min=1, max=20),
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _parse_class_functionality(self, spo):
class_view = ClassView.model_validate_json(spo.object_)
class_view_content = f"```mermaid\n{class_view.get_mermaid(align=0)}\n```"
rsp = await self.llm.aask(
msg=f"## Class View\n{class_view_content}",
system_msgs=[
"You are a tool capable of translating class views into a textual description of their functionalities and goals.",
'Please return a Markdown JSON format with a "description" key containing a concise textual description of the class functionalities, a "goal" key containing a concise textual description of the problem the class aims to solve, and a "reason" key explaining why.',
],
)
class _JsonCodeBlock(BaseModel):
description: str
goal: str
reason: Optional[str] = None
code_block = rsp.removeprefix("```json\n").removesuffix("```")
data = _JsonCodeBlock.model_validate_json(code_block)
data.reason = None
functionality = data.model_dump_json(exclude_none=True)
await self.graph_db.insert(subject=spo.subject, predicate=GraphKeyword.HAS_CLASS_DESC, object_=functionality)
return functionality, class_view_content
@retry(
wait=wait_random_exponential(min=1, max=20),
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _merge_sequence_view(self, sequence_view, class_view, class_functionality) -> str:
contents = [
f"## Mermaid Class View\n{class_view}",
f"## Class Description\n{class_functionality}",
f"## Mermaid Sequence View\n{sequence_view}",
]
msg = "\n---\n".join(contents)
rsp = await self.llm.aask(
msg=msg,
system_msgs=[
"You are a tool to merge Mermaid class view information into the Mermaid sequence view.",
'Append as much information as possible from the "Mermaid Class View" and "Class Description" to the sequence diagram.',
'Return a markdown JSON format with a "sequence_diagram" key containing the merged Mermaid sequence view, a "reason" key explaining what information have been merged and why.',
],
)
class _JsonCodeBlock(BaseModel):
sequence_diagram: str
reason: str
code_block = rsp.removeprefix("```json\n").removesuffix("```")
data = _JsonCodeBlock.model_validate_json(code_block)
return data.sequence_diagram
@staticmethod
def _get_full_filename(root: str | Path, pathname: str | Path) -> Path | None:
@ -60,3 +145,26 @@ class RebuildSequenceView(Action):
if str(i).endswith(postfix):
return i
return None
@staticmethod
def parse_participant(mermaid_sequence_diagram: str) -> List[str]:
pattern = r"participant (\w+)"
matches = re.findall(pattern, mermaid_sequence_diagram)
return matches
async def _get_class_views(self, class_names) -> List[SPO]:
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
ns_class_names = {}
for r in rows:
ns, class_name = split_namespace(r.subject)
if class_name in class_names:
ns_class_names[r.subject] = class_name
class_views = []
for ns_name in ns_class_names.keys():
views = await self.graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
if not views:
continue
class_views += views
return class_views

View file

@ -1,16 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4
@Author : mashenquan
@File : rebuild_sequence_view_an.py
"""
from metagpt.actions.action_node import ActionNode
from metagpt.utils.mermaid import MMC2
CODE_2_MERMAID_SEQUENCE_DIAGRAM = ActionNode(
key="Program call flow",
expected_type=str,
instruction='Translate the "context" content into "format example" format.',
example=MMC2,
)

View file

@ -408,7 +408,7 @@ def concat_namespace(*args) -> str:
def split_namespace(ns_class_name: str) -> List[str]:
return ns_class_name.split(":")
return ns_class_name.split(":", maxsplit=1)
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
@ -567,3 +567,8 @@ def list_files(root: str | Path) -> List[Path]:
except Exception as e:
logger.error(f"Error: {e}")
return files
def parse_json_code_block(markdown_text: str) -> List[str]:
json_blocks = re.findall(r"```json(.*?)```", markdown_text, re.DOTALL)
return [v.strip() for v in json_blocks]

View file

@ -26,12 +26,6 @@ class DiGraphRepository(GraphRepository):
async def insert(self, subject: str, predicate: str, object_: str):
self._repo.add_edge(subject, object_, predicate=predicate)
async def upsert(self, subject: str, predicate: str, object_: str):
pass
async def update(self, subject: str, predicate: str, object_: str):
pass
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
result = []
for s, o, p in self._repo.edges(data="predicate"):
@ -44,6 +38,14 @@ class DiGraphRepository(GraphRepository):
result.append(SPO(subject=s, predicate=p, object_=o))
return result
async def delete(self, subject: str = None, predicate: str = None, object_: str = None) -> int:
rows = await self.select(subject=subject, predicate=predicate, object_=object_)
if not rows:
return 0
for r in rows:
self._repo.remove_edge(r.subject, r.object_)
return len(rows)
def json(self) -> str:
m = networkx.node_link_data(self._repo)
data = json.dumps(m)

View file

@ -33,6 +33,7 @@ class GraphKeyword:
HAS_CLASS_FUNCTION = "has_class_function"
HAS_CLASS_PROPERTY = "has_class_property"
HAS_CLASS = "has_class"
HAS_CLASS_DESC = "has_class_desc"
HAS_PAGE_INFO = "has_page_info"
HAS_CLASS_VIEW = "has_class_view"
HAS_SEQUENCE_VIEW = "has_sequence_view"
@ -55,18 +56,18 @@ class GraphRepository(ABC):
async def insert(self, subject: str, predicate: str, object_: str):
pass
@abstractmethod
async def upsert(self, subject: str, predicate: str, object_: str):
pass
@abstractmethod
async def update(self, subject: str, predicate: str, object_: str):
pass
@abstractmethod
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
pass
@abstractmethod
async def delete(self, subject: str = None, predicate: str = None, object_: str = None) -> int:
pass
@abstractmethod
async def save(self):
pass
@property
def name(self) -> str:
return self._repo_name