feat: refactor class description

This commit is contained in:
莘权 马 2024-01-19 15:54:14 +08:00
parent 78af904f5e
commit 7be58b07b7

View file

@ -8,6 +8,7 @@
"""
from __future__ import annotations
import json
import re
from pathlib import Path
from typing import List, Optional
@ -17,7 +18,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions import Action
from metagpt.config2 import config
from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.const import GENERALIZATION, GRAPH_REPO_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import ClassView
from metagpt.utils.common import aread, general_after_log, list_files, split_namespace
@ -66,8 +67,8 @@ class RebuildSequenceView(Action):
for cv in class_views:
if cv.subject in merged_class_views:
continue
class_functionality, class_view = await self._parse_class_functionality(cv)
sequence_view = await self._merge_sequence_view(sequence_view, class_view, class_functionality)
await self._parse_class_description(cv.subject)
sequence_view = await self._merge_sequence_view(sequence_view, cv.subject)
diff.add(cv.subject)
merged_class_views.add(cv.subject)
@ -85,14 +86,63 @@ class RebuildSequenceView(Action):
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _parse_class_functionality(self, spo):
class_view = ClassView.model_validate_json(spo.object_)
class_view_content = f"```mermaid\n{class_view.get_mermaid(align=0)}\n```"
async def _parse_class_description(self, ns_class_name: str):
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_DESC)
if rows:
return
me_class_views = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
if not me_class_views:
# Loss of necessary information to create the description.
await self.graph_db.insert(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_VIEW, object_="")
return
# prepare base-class description
rows = await self.graph_db.select(
subject=ns_class_name, predicate=GraphKeyword.IS + GENERALIZATION + GraphKeyword.OF
)
ns_base_class_names = [r.object_ for r in rows]
ns_base_class_views = {}
ns_base_class_descs = {}
for name in ns_base_class_names:
class_views = await self.graph_db.select(subject=name, predicate=GraphKeyword.HAS_CLASS_VIEW)
if rows:
ns_base_class_views[name] = class_views
descs = await self.graph_db.select(subject=name, predicate=GraphKeyword.HAS_CLASS_DESC)
if not descs:
# Haven't been parsed before.
await self._parse_class_description(ns_class_name=name)
descs = await self.graph_db.select(subject=name, predicate=GraphKeyword.HAS_CLASS_DESC)
ns_base_class_descs[name] = descs
# parse class description
prompt = "```mermaid\nclassDiagram\n"
# - add base-class description
for ns_name in ns_base_class_names:
descs = ns_base_class_descs.get(ns_name, [])
for r in descs:
notes = self._desc_to_note(r.object_)
ns, name = split_namespace(r.subject)
for n in notes:
prompt += f'\n\tnote for {name} "{n}"'
views = ns_base_class_views.get(ns_name, [])
for r in views:
cv = ClassView.model_validate_json(r.object_)
prompt += "\n" + cv.get_mermaid(align=1)
# - add relationship
_, me = split_namespace(ns_class_name)
for ns_name in ns_base_class_names:
ns, base = split_namespace(ns_name)
prompt += f"\n\t{base} <|-- {me}"
# - add me
cv = ClassView.model_validate_json(me_class_views[0].object_)
prompt += "\n" + cv.get_mermaid(align=1)
prompt += "\n```"
rsp = await self.llm.aask(
msg=f"## Class View\n{class_view_content}",
msg=prompt,
system_msgs=[
"You are a tool capable of translating class views into a textual description of their functionalities and goals.",
'Please return a Markdown JSON format with a "description" key containing a concise textual description of the class functionalities, a "goal" key containing a concise textual description of the problem the class aims to solve, and a "reason" key explaining why.',
f'Please return a Markdown JSON format with a "description" key containing a concise textual description of the `{me}` class functionalities, a "goal" key containing a concise textual description of the problem the `{me}` class aims to solve, and a "reason" key explaining why.',
],
)
@ -105,23 +155,32 @@ class RebuildSequenceView(Action):
data = _JsonCodeBlock.model_validate_json(code_block)
data.reason = None
functionality = data.model_dump_json(exclude_none=True)
await self.graph_db.insert(subject=spo.subject, predicate=GraphKeyword.HAS_CLASS_DESC, object_=functionality)
return functionality, class_view_content
await self.graph_db.insert(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_DESC, object_=functionality)
@retry(
wait=wait_random_exponential(min=1, max=20),
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _merge_sequence_view(self, sequence_view, class_view, class_functionality) -> str:
async def _merge_sequence_view(self, sequence_view, ns_class_name) -> str:
class_view_part = "```mermaid\n"
class_view_rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
class_view_part += (
ClassView.model_validate_json(class_view_rows[0].object_).get_mermaid() if class_view_rows else ""
)
_, name = split_namespace(class_view_rows[0].subject)
class_desc_rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_DESC)
for r in self._desc_to_note(class_desc_rows[0].object_):
class_view_part += f'\n\tnote for {name} "{r}"'
class_view_part += "\n```"
contents = [
f"## Mermaid Class View\n{class_view}",
f"## Class Description\n{class_functionality}",
f"## Mermaid Class View\n{class_view_part}\n",
f"## Mermaid Sequence View\n{sequence_view}",
]
msg = "\n---\n".join(contents)
prompt = "\n---\n".join(contents)
rsp = await self.llm.aask(
msg=msg,
msg=prompt,
system_msgs=[
"You are a tool to merge Mermaid class view information into the Mermaid sequence view.",
'Append as much information as possible from the "Mermaid Class View" and "Class Description" to the sequence diagram.',
@ -168,3 +227,10 @@ class RebuildSequenceView(Action):
continue
class_views += views
return class_views
@staticmethod
def _desc_to_note(json_str) -> List[str]:
if not json_str:
return []
m = json.loads(json_str)
return [s.replace('"', '\\"').replace("\n", "\\n") for s in m.values()]