feat: +source -> use case -> sequence view

This commit is contained in:
莘权 马 2024-01-26 19:39:06 +08:00
parent 633c772529
commit 67bf89996b
5 changed files with 251 additions and 206 deletions

View file

@ -23,7 +23,7 @@ from metagpt.const import (
)
from metagpt.logs import logger
from metagpt.repo_parser import DotClassInfo, RepoParser
from metagpt.schema import UMLClassAttribute, UMLClassMethod, UMLClassView
from metagpt.schema import UMLClassView
from metagpt.utils.common import concat_namespace, split_namespace
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
@ -86,18 +86,7 @@ class RebuildClassView(Action):
if not rows:
return ""
dot_class_info = DotClassInfo.model_validate_json(rows[0].object_)
visibility = UMLClassView.name_to_visibility(dot_class_info.name)
class_view = UMLClassView(name=dot_class_info.name, visibility=visibility)
for i in dot_class_info.attributes.values():
visibility = UMLClassAttribute.name_to_visibility(i.name)
attr = UMLClassAttribute(name=i.name, visibility=visibility, value_type=i.type_, default_value=i.default_)
class_view.attributes.append(attr)
for i in dot_class_info.methods.values():
visibility = UMLClassMethod.name_to_visibility(i.name)
method = UMLClassMethod(name=i.name, visibility=visibility, return_type=i.return_args.type_)
for j in i.args:
arg = UMLClassAttribute(name=j.name, value_type=j.type_, default_value=j.default_)
method.args.append(arg)
class_view = UMLClassView.load_dot_class_info(dot_class_info)
# update uml view
await self.graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())

View file

@ -8,24 +8,47 @@
"""
from __future__ import annotations
import json
import re
from pathlib import Path
from typing import Dict, List, Optional
from typing import List, Optional
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions import Action
from metagpt.config2 import config
from metagpt.const import AGGREGATION, GENERALIZATION, GRAPH_REPO_FILE_REPO
from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import ClassView
from metagpt.utils.common import aread, general_after_log, list_files, split_namespace
from metagpt.repo_parser import CodeBlockInfo, DotClassInfo
from metagpt.schema import UMLClassView
from metagpt.utils.common import (
aread,
concat_namespace,
general_after_log,
list_files,
parse_json_code_block,
read_file_block,
split_namespace,
)
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import SPO, GraphKeyword, GraphRepository
class SQVUseCase(BaseModel):
    """A single use case; the fields match the keys requested from the LLM for each entry of "use_cases"."""

    description: str  # what the use case does
    inputs: List[str]  # input names of the use case coming from external sources
    outputs: List[str]  # output names of the use case going to external sources
    actors: List[str]  # participant actors of the use case
    steps: List[str]  # how the use case works, step by step
    reason: str  # under what circumstances the external system would execute this use case
class SQVUseCaseDetails(BaseModel):
    """Use-case analysis of a piece of source code, validated from a JSON code block of an LLM reply."""

    description: str  # what the whole source code wants to do
    use_cases: List[SQVUseCase]  # every use case that was identified
    relationship: List[str]  # descriptions of the relationships among the use cases
class RebuildSequenceView(Action):
graph_db: Optional[GraphRepository] = None
@ -34,7 +57,9 @@ class RebuildSequenceView(Action):
self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
entries = await self._search_main_entry()
for entry in entries:
await self._rebuild_sequence_view(entry)
await self._rebuild_main_sequence_view(entry)
while await self._merge_sequence_view(entry):
pass
await self.graph_db.save()
async def _search_main_entry(self) -> List:
@ -46,173 +71,163 @@ class RebuildSequenceView(Action):
entries.append(r)
return entries
async def _rebuild_sequence_view(self, entry):
filename = entry.subject.split(":", 1)[0]
src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
if not src_filename:
return
content = await aread(filename=src_filename, encoding="utf-8")
content = f"```python\n{content}\n```\n\n---\nTranslate the code above into Mermaid Sequence Diagram."
sequence_view = await self.llm.aask(
msg=content, system_msgs=["You are a python code to Mermaid Sequence Diagram translator in function detail"]
)
await self.graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view)
logger.info(sequence_view)
merged_class_views = set()
while True:
participants = RebuildSequenceView.parse_participant(sequence_view)
class_views, class_compositions = await self._get_class_views(participants)
for compositions in class_compositions.values():
for c in compositions:
ns, _ = split_namespace(c.object_)
if ns == "?":
continue
await self._parse_class_description(c.object_)
diff = set()
for cv in class_views:
if cv.subject in merged_class_views:
continue
await self._parse_class_description(cv.subject)
sequence_view = await self._merge_sequence_view(
sequence_view, cv.subject, class_compositions.get(cv.subject, [])
)
diff.add(cv.subject)
merged_class_views.add(cv.subject)
await self.graph_db.delete(subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
await self.graph_db.insert(
subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
)
logger.info(sequence_view)
if diff:
continue
break
@retry(
wait=wait_random_exponential(min=1, max=20),
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _parse_class_description(self, ns_class_name: str):
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_DESC)
async def _rebuild_use_case(self, ns_class_name):
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE)
if rows:
return
me_class_views = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
if not me_class_views:
# Loss of necessary information to create the description.
await self.graph_db.insert(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_VIEW, object_="")
detail = await self._get_class_detail(ns_class_name)
if not detail:
return
participants = set()
participants.update(set(detail.compositions))
participants.update(set(detail.aggregations))
class_view = await self._get_uml_class_view(ns_class_name)
source_code = await self._get_source_code(ns_class_name)
# prepare base-class description
rows = await self.graph_db.select(
subject=ns_class_name, predicate=GraphKeyword.IS + GENERALIZATION + GraphKeyword.OF
)
ns_base_class_names = [r.object_ for r in rows]
ns_base_class_views = {}
ns_base_class_descs = {}
for name in ns_base_class_names:
class_views = await self.graph_db.select(subject=name, predicate=GraphKeyword.HAS_CLASS_VIEW)
if rows:
ns_base_class_views[name] = class_views
descs = await self.graph_db.select(subject=name, predicate=GraphKeyword.HAS_CLASS_DESC)
if not descs:
# Haven't been parsed before.
await self._parse_class_description(ns_class_name=name)
descs = await self.graph_db.select(subject=name, predicate=GraphKeyword.HAS_CLASS_DESC)
ns_base_class_descs[name] = descs
# parse class description
prompt = "```mermaid\nclassDiagram\n"
# - add base-class description
for ns_name in ns_base_class_names:
descs = ns_base_class_descs.get(ns_name, [])
for r in descs:
notes = self._desc_to_note(r.object_)
ns, name = split_namespace(r.subject)
for n in notes:
prompt += f'\n\tnote for {name} "{n}"'
views = ns_base_class_views.get(ns_name, [])
for r in views:
cv = ClassView.model_validate_json(r.object_)
prompt += "\n" + cv.get_mermaid(align=1)
# - add relationship
_, me = split_namespace(ns_class_name)
for ns_name in ns_base_class_names:
ns, base = split_namespace(ns_name)
prompt += f"\n\t{base} <|-- {me}"
# - add me
cv = ClassView.model_validate_json(me_class_views[0].object_)
prompt += "\n" + cv.get_mermaid(align=1)
prompt += "\n```"
prompt_blocks = []
block = "## Participants\n"
for p in participants:
block += f"- {p}\n"
prompt_blocks.append(block)
block = "## Mermaid Class Views\n```mermaid\n"
block += class_view.get_mermaid()
block += "\n```\n"
prompt_blocks.append(block)
block = "## Source Code\n```python\n"
block += source_code
block += "\n```\n"
prompt_blocks.append(block)
prompt = "\n---\n".join(prompt_blocks)
rsp = await self.llm.aask(
msg=prompt,
system_msgs=[
"You are a tool capable of translating class views into a textual description of their functionalities and goals.",
f'Please return a Markdown JSON format with a "description" key containing a concise textual description of the `{me}` class functionalities, a "goal" key containing a concise textual description of the problem the `{me}` class aims to solve, and a "reason" key explaining why.',
"You are a python code to UML 2.0 Use Case translator.",
'The generated UML 2.0 Use Case must include the roles or entities listed in "Participants".',
'The functional descriptions of Actors and Use Cases in the generated UML 2.0 Use Case must not conflict with the information in "Mermaid Class Views".',
#'Only steps that involve input, output, and interactive operations with the external system at the same time can be considered as independent use cases.',
"Only steps that involve input, output, and interactive operations with the external system at the same time can be considered as independent use cases, steps that do not meet any one condition should be incorporated into other use cases.",
'The section under `if __name__ == "__main__":` of "Source Code" contains information about external system interactions with the internal system.',
"Return a markdown JSON object with:\n"
'- a "description" key to explain what the whole source code want to do;\n'
'- a "use_cases" key list all use cases, each use case in the list should including a `description` key describes about what the use case to do, a `inputs` key lists the input names of the use case from external sources, a `outputs` key lists the output names of the use case to external sources, a `actors` key lists the participant actors of the use case, a `steps` key lists the steps about how the use case works step by step, a `reason` key explaining under what circumstances would the external system execute this use case.\n'
'- a "relationship" key lists the descriptions of relationship among these use cases.\n',
],
)
class _JsonCodeBlock(BaseModel):
description: str
goal: str
reason: Optional[str] = None
code_blocks = parse_json_code_block(rsp)
for block in code_blocks:
detail = SQVUseCaseDetails.model_validate_json(block)
await self.graph_db.insert(
subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE, object_=detail.model_dump_json()
)
code_block = rsp.removeprefix("```json\n").removesuffix("```")
data = _JsonCodeBlock.model_validate_json(code_block)
data.reason = None
functionality = data.model_dump_json(exclude_none=True)
await self.graph_db.insert(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_DESC, object_=functionality)
async def _rebuild_main_sequence_view(self, entry):
    """Ask the LLM for a Mermaid sequence diagram of the entry file and store it in the graph.

    The prompt is assembled from the use cases, participants (compositions and
    aggregations), Mermaid class views and source code of every class whose
    graph subject contains ``"<filename>:"``.  The resulting diagram is stored
    under ``HAS_SEQUENCE_VIEW`` for ``entry.subject`` and each matched class is
    recorded under ``HAS_PARTICIPANT``.

    Args:
        entry: Graph row whose ``subject`` has the form ``"<filename>:<name>"``.

    Returns:
        The sequence view text returned by the LLM.
    """
    source_file = entry.subject.split(":", 1)[0]
    marker = source_file + ":"
    all_classes = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)

    # Keep the classes of the entry file and make sure each one has use cases.
    # NOTE(review): this is a substring test, not a prefix test - confirm that is intended.
    classes = []
    for row in all_classes:
        if marker not in row.subject:
            continue
        classes.append(row)
        await self._rebuild_use_case(row.subject)

    participants = set()
    class_details = []  # collected, but not read again within this method
    class_views = []
    for cls_row in classes:
        detail = await self._get_class_detail(cls_row.subject)
        if detail:
            class_details.append(detail)
            for related in (detail.compositions, detail.aggregations):
                participants.update(set(related))
            uml_view = await self._get_uml_class_view(cls_row.subject)
            if uml_view:
                class_views.append(uml_view)

    use_case_blocks = []
    for cls_row in classes:
        use_case_blocks.extend(await self._get_class_use_cases(cls_row.subject))

    participants_md = "## Participants\n" + "".join([f"- {name}\n" for name in participants])
    mermaid_views = "\n\n".join([view.get_mermaid() for view in class_views])
    class_views_md = "## Mermaid Class Views\n```mermaid\n" + mermaid_views + "\n```\n"
    source_code = await self._get_source_code(source_file)
    source_md = "## Source Code\n```python\n" + source_code + "\n```\n"

    prompt = "\n---\n".join(["\n".join(use_case_blocks), participants_md, class_views_md, source_md])
    sequence_view = await self.llm.aask(
        msg=prompt, system_msgs=["You are a python code to Mermaid Sequence Diagram translator in function detail"]
    )

    await self.graph_db.insert(
        subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
    )
    for cls_row in classes:
        await self.graph_db.insert(
            subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=cls_row.subject
        )
    return sequence_view
async def _get_class_use_cases(self, ns_class_name) -> List[str]:
    """Render the stored use cases of a class as Markdown blocks.

    Args:
        ns_class_name: Graph subject of the class whose ``HAS_CLASS_USE_CASE`` rows are read.

    Returns:
        One Markdown string per use case, with its description as the heading
        followed by bullet lists of inputs, outputs, actors and steps.
    """
    rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE)
    markdown_blocks = []
    for row in rows:
        details = SQVUseCaseDetails.model_validate_json(row.object_)
        for use_case in details.use_cases:
            pieces = [f"## Use Cases: {use_case.description}\n"]
            sections = (
                ("### Inputs:\n", use_case.inputs),
                ("### Outputs:\n", use_case.outputs),
                ("### Actors:\n", use_case.actors),
                ("### Steps:\n", use_case.steps),
            )
            for heading, items in sections:
                pieces.append(heading)
                pieces.extend(f"- {item}\n" for item in items)
            markdown_blocks.append("".join(pieces))
    return markdown_blocks
async def _get_class_detail(self, ns_class_name) -> DotClassInfo | None:
    """Load the ``DotClassInfo`` stored under ``HAS_DETAIL`` for a class.

    Only the first matching row is used.

    Returns:
        The parsed ``DotClassInfo``, or ``None`` when the graph has no detail row for the class.
    """
    detail_rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_DETAIL)
    if detail_rows:
        return DotClassInfo.model_validate_json(detail_rows[0].object_)
    return None
async def _get_uml_class_view(self, ns_class_name) -> UMLClassView | None:
    """Load the ``UMLClassView`` stored under ``HAS_CLASS_VIEW`` for a class.

    Only the first matching row is used.

    Returns:
        The parsed ``UMLClassView``, or ``None`` when the graph has no class-view row for the class.
    """
    view_rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
    if view_rows:
        return UMLClassView.model_validate_json(view_rows[0].object_)
    return None
async def _get_source_code(self, ns_class_name) -> str:
    """Return the source text belonging to a graph subject.

    When the graph holds a ``HAS_PAGE_INFO`` row for the subject, only the line
    range recorded in its first row is read.  Otherwise the whole file named by
    the namespace part of ``ns_class_name`` is read, resolved against ``self.i_context``.

    Returns:
        The source text, or ``""`` when the file cannot be resolved.
    """
    page_info_rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_PAGE_INFO)
    filename = split_namespace(ns_class_name=ns_class_name)[0]

    if page_info_rows:
        span = CodeBlockInfo.model_validate_json(page_info_rows[0].object_)
        # NOTE(review): this branch reads `filename` as-is, whereas the fallback below
        # resolves it against `self.i_context` first - confirm the relative path is intended.
        return await read_file_block(filename=filename, lineno=span.lineno, end_lineno=span.end_lineno)

    full_path = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
    if full_path:
        return await aread(filename=full_path, encoding="utf-8")
    return ""
@retry(
wait=wait_random_exponential(min=1, max=20),
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _merge_sequence_view(self, sequence_view, ns_class_name, compositions) -> str:
class_view_part = "```mermaid\n"
# add class
class_view_part += await self._class_view_to_mermaid(ns_class_name)
# add aggregation relationship
_, me = split_namespace(ns_class_name)
rows = await self.graph_db.select(
subject=ns_class_name, predicate=GraphKeyword.IS + AGGREGATION + GraphKeyword.OF
)
aggregation = [r.object_ for r in rows]
for ns_aggr_name in aggregation:
_, name = split_namespace(ns_aggr_name)
class_view_part += f"\n\t{me} *-- {name}"
class_view_part += await self._class_view_to_mermaid(ns_aggr_name)
# add composition relationship
for c in compositions:
_, name = split_namespace(c.object_)
class_view_part += f"\n\t{me} *-- {name}"
class_view_part += await self._class_view_to_mermaid(c.object_)
async def _merge_sequence_view(self, entry) -> bool:
new_participant = await self._search_new_participant(entry)
if not new_participant:
return False
class_view_part += "\n```"
contents = [
f"## Mermaid Class View\n{class_view_part}\n",
f"## Mermaid Sequence View\n{sequence_view}",
]
prompt = "\n---\n".join(contents)
rsp = await self.llm.aask(
msg=prompt,
system_msgs=[
"You are a tool to merge Mermaid class view information into the Mermaid sequence view.",
'Append as much information as possible from the "Mermaid Class View" to the sequence diagram.',
'Return a markdown JSON format with a "sequence_diagram" key containing the merged Mermaid sequence view, a "reason" key explaining in detail what information have been merged and why.',
],
)
class _JsonCodeBlock(BaseModel):
sequence_diagram: str
reason: str
code_block = rsp.removeprefix("```json\n").removesuffix("```")
data = _JsonCodeBlock.model_validate_json(code_block)
return data.sequence_diagram
await self._merge_participant(entry, new_participant)
return True
@staticmethod
def _get_full_filename(root: str | Path, pathname: str | Path) -> Path | None:
@ -225,48 +240,70 @@ class RebuildSequenceView(Action):
@staticmethod
def parse_participant(mermaid_sequence_diagram: str) -> List[str]:
pattern = r"participant (\w+)"
pattern = r"participant ([a-zA-Z\.0-9_]+)"
matches = re.findall(pattern, mermaid_sequence_diagram)
return matches
async def _get_class_views(self, class_names) -> (List[SPO], Dict[str, List[SPO]]):
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
ns_class_names = {}
async def _search_new_participant(self, entry: SPO) -> str | None:
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
if not rows:
return None
sequence_view = rows[0].object_
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT)
merged_participants = []
for r in rows:
ns, class_name = split_namespace(r.subject)
if class_name in class_names:
ns_class_names[r.subject] = class_name
class_views = []
class_compositions = {}
for ns_name in ns_class_names.keys():
views = await self.graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
if not views:
_, name = split_namespace(r.object_)
merged_participants.append(name)
participants = self.parse_participant(sequence_view)
for p in participants:
if p in merged_participants:
continue
class_views += views
compositions = await self.graph_db.select(subject=ns_name, predicate=GraphKeyword.IS_COMPOSITE_OF)
class_compositions[ns_name] = compositions
return class_views, class_compositions
return p
return None
@staticmethod
def _desc_to_note(json_str) -> List[str]:
if not json_str:
return []
m = json.loads(json_str)
return [s.replace('"', '\\"').replace("\n", "\\n") for s in m.values()]
async def _merge_participant(self, entry: SPO, class_name: str):
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
participants = []
for r in rows:
_, name = split_namespace(r.subject)
if name == class_name:
participants.append(name)
if len(participants) == 0:
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=concat_namespace("?", class_name)
)
return
if len(participants) > 1:
for r in participants:
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=r.object_
)
return
async def _class_view_to_mermaid(self, ns_class_name) -> str:
class_view_rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
if not class_view_rows:
return ""
result = ClassView.model_validate_json(class_view_rows[0].object_).get_mermaid() if class_view_rows else ""
_, name = split_namespace(ns_class_name)
class_desc_rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_DESC)
if not class_desc_rows:
# Haven't been parsed before.
await self._parse_class_description(ns_class_name=ns_class_name)
class_desc_rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_DESC)
for r in self._desc_to_note(class_desc_rows[0].object_):
result += f'\n\tnote for {name} "{r}"'
return result
participant = participants[0]
await self._rebuild_use_case(participant.subject)
participants = set()
detail = await self._get_class_detail(participant.subject)
if not detail:
return
participants.update(set(detail.compositions))
participants.update(set(detail.aggregations))
view = await self._get_uml_class_view(participant.subject)
use_cases = await self._get_class_use_cases(participant.subject)
prompt_blocks = ["\n".join(use_cases)]
block = "## Participants\n"
for p in participants:
block += f"- {p}\n"
prompt_blocks.append(block)
block = "## Mermaid Class Views\n```mermaid\n"
block += view.get_mermaid()
block += "\n```\n"
prompt_blocks.append(block)
block = "## Source Code\n```python\n"
block += await self._get_source_code(participant.subject)
block += "\n```\n"
prompt_blocks.append(block)
prompt = "\n---\n".join(prompt_blocks)
self.llm.aask(prompt, system_msgs=["You are a tool to cooperator"])

View file

@ -45,6 +45,7 @@ from metagpt.const import (
TASK_FILE_REPO,
)
from metagpt.logs import logger
from metagpt.repo_parser import DotClassInfo
from metagpt.utils.common import any_to_str, any_to_str_set, import_class
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.serialize import (
@ -538,3 +539,19 @@ class UMLClassView(UMLClassMeta):
content += v.get_mermaid(align=align + 1) + "\n"
content += "".join(["\t" for i in range(align)]) + "}\n"
return content
@classmethod
def load_dot_class_info(cls, dot_class_info: DotClassInfo) -> UMLClassView:
    """Build a UML class view from parsed dot class information.

    Visibility of the class, its attributes and its methods is derived from
    their names via the respective ``name_to_visibility`` helpers.

    Args:
        dot_class_info: Class description providing ``name``, ``attributes`` and ``methods``.

    Returns:
        A new view containing one attribute per entry of ``dot_class_info.attributes``
        and one method (with its arguments) per entry of ``dot_class_info.methods``.
    """
    visibility = UMLClassView.name_to_visibility(dot_class_info.name)
    class_view = cls(name=dot_class_info.name, visibility=visibility)

    for attr_info in dot_class_info.attributes.values():
        attr = UMLClassAttribute(
            name=attr_info.name,
            visibility=UMLClassAttribute.name_to_visibility(attr_info.name),
            value_type=attr_info.type_,
            default_value=attr_info.default_,
        )
        class_view.attributes.append(attr)

    for method_info in dot_class_info.methods.values():
        method = UMLClassMethod(
            name=method_info.name,
            visibility=UMLClassMethod.name_to_visibility(method_info.name),
            return_type=method_info.return_args.type_,
        )
        for arg_info in method_info.args:
            arg = UMLClassAttribute(name=arg_info.name, value_type=arg_info.type_, default_value=arg_info.default_)
            method.args.append(arg)
        # Without this append the method would be built and then discarded,
        # leaving the returned view with no methods at all.
        class_view.methods.append(method)

    return class_view

View file

@ -37,8 +37,10 @@ class GraphKeyword:
HAS_PAGE_INFO = "has_page_info"
HAS_CLASS_VIEW = "has_class_view"
HAS_SEQUENCE_VIEW = "has_sequence_view"
HAS_CLASS_USE_CASE = "has_class_use_case"
IS_COMPOSITE_OF = "is_composite_of"
IS_AGGREGATE_OF = "is_aggregate_of"
HAS_PARTICIPANT = "has_participant"
class SPO(BaseModel):

File diff suppressed because one or more lines are too long