feat: + composition, aggregation relationship

This commit is contained in:
莘权 马 2024-01-19 17:24:05 +08:00
parent 7be58b07b7
commit 831ddb1736

View file

@ -18,7 +18,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions import Action
from metagpt.config2 import config
from metagpt.const import GENERALIZATION, GRAPH_REPO_FILE_REPO
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION, GRAPH_REPO_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import ClassView
from metagpt.utils.common import aread, general_after_log, list_files, split_namespace
@ -164,14 +164,28 @@ class RebuildSequenceView(Action):
)
async def _merge_sequence_view(self, sequence_view, ns_class_name) -> str:
class_view_part = "```mermaid\n"
class_view_rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
class_view_part += (
ClassView.model_validate_json(class_view_rows[0].object_).get_mermaid() if class_view_rows else ""
# add class
class_view_part += await self._class_view_to_mermaid(ns_class_name)
# add aggregation relationship
_, me = split_namespace(ns_class_name)
rows = await self.graph_db.select(
subject=ns_class_name, predicate=GraphKeyword.IS + AGGREGATION + GraphKeyword.OF
)
_, name = split_namespace(class_view_rows[0].subject)
class_desc_rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_DESC)
for r in self._desc_to_note(class_desc_rows[0].object_):
class_view_part += f'\n\tnote for {name} "{r}"'
aggregation = [r.object_ for r in rows]
for ns_aggr_name in aggregation:
_, name = split_namespace(ns_aggr_name)
class_view_part += f"\n\t{me} *-- {name}"
class_view_part += await self._class_view_to_mermaid(ns_aggr_name)
# add composition relationship
rows = await self.graph_db.select(
subject=ns_class_name, predicate=GraphKeyword.IS + COMPOSITION + GraphKeyword.OF
)
compositions = [r.object_ for r in rows]
for ns_comp_name in compositions:
_, name = split_namespace(ns_comp_name)
class_view_part += f"\n\t{me} *-- {name}"
class_view_part += await self._class_view_to_mermaid(ns_comp_name)
class_view_part += "\n```"
contents = [
@ -183,7 +197,7 @@ class RebuildSequenceView(Action):
msg=prompt,
system_msgs=[
"You are a tool to merge Mermaid class view information into the Mermaid sequence view.",
'Append as much information as possible from the "Mermaid Class View" and "Class Description" to the sequence diagram.',
'Append as much information as possible from the "Mermaid Class View" to the sequence diagram.',
'Return a markdown JSON format with a "sequence_diagram" key containing the merged Mermaid sequence view, a "reason" key explaining what information have been merged and why.',
],
)
@ -234,3 +248,16 @@ class RebuildSequenceView(Action):
return []
m = json.loads(json_str)
return [s.replace('"', '\\"').replace("\n", "\\n") for s in m.values()]
async def _class_view_to_mermaid(self, ns_class_name):
class_view_rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
result = ClassView.model_validate_json(class_view_rows[0].object_).get_mermaid() if class_view_rows else ""
_, name = split_namespace(ns_class_name)
class_desc_rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_DESC)
if not class_desc_rows:
# Haven't been parsed before.
await self._parse_class_description(ns_class_name=ns_class_name)
class_desc_rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_DESC)
for r in self._desc_to_note(class_desc_rows[0].object_):
result += f'\n\tnote for {name} "{r}"'
return result