feat: +visual graph repo

This commit is contained in:
莘权 马 2024-01-31 23:40:04 +08:00
parent c1552d7319
commit fb7518c12b
13 changed files with 217 additions and 18 deletions

View file

@ -9,6 +9,7 @@
from __future__ import annotations
import re
from datetime import datetime
from pathlib import Path
from typing import List, Optional
@ -22,7 +23,9 @@ from metagpt.logs import logger
from metagpt.repo_parser import CodeBlockInfo, DotClassInfo
from metagpt.schema import UMLClassView
from metagpt.utils.common import (
add_affix,
aread,
auto_namespace,
concat_namespace,
general_after_log,
list_files,
@ -119,11 +122,22 @@ class RebuildSequenceView(Action):
],
)
sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
for r in rows:
await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
)
await self.graph_db.insert(
subject=entry.subject,
predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER,
object_=concat_namespace(datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3], add_affix(sequence_view)),
)
for c in classes:
await self.graph_db.insert(subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=c.subject)
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(c.subject)
)
await self.graph_db.save()
async def _merge_sequence_view(self, entry) -> bool:
new_participant = await self._search_new_participant(entry)
@ -198,6 +212,7 @@ class RebuildSequenceView(Action):
await self.graph_db.insert(
subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE, object_=detail.model_dump_json()
)
await self.graph_db.save()
@retry(
wait=wait_random_exponential(min=1, max=20),
@ -211,6 +226,7 @@ class RebuildSequenceView(Action):
use_case_markdown = await self._get_class_use_cases(ns_class_name)
if not use_case_markdown: # external class
await self.graph_db.insert(subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_="")
await self.graph_db.save()
return
block = f"## Use Cases\n{use_case_markdown}"
prompts_blocks.append(block)
@ -244,6 +260,7 @@ class RebuildSequenceView(Action):
await self.graph_db.insert(
subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
)
await self.graph_db.save()
async def _get_participants(self, ns_class_name) -> List[str]:
participants = set()
@ -319,7 +336,7 @@ class RebuildSequenceView(Action):
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT)
merged_participants = []
for r in rows:
_, name = split_namespace(r.object_)
name = split_namespace(r.object_)[-1]
merged_participants.append(name)
participants = self.parse_participant(sequence_view)
for p in participants:
@ -337,19 +354,21 @@ class RebuildSequenceView(Action):
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
participants = []
for r in rows:
_, name = split_namespace(r.subject)
name = split_namespace(r.subject)[-1]
if name == class_name:
participants.append(r)
if len(participants) == 0: # external participants
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=concat_namespace("?", class_name)
)
await self.graph_db.save()
return
if len(participants) > 1:
for r in participants:
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=r.object_
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(r.subject)
)
await self.graph_db.save()
return
participant = participants[0]
@ -372,9 +391,18 @@ class RebuildSequenceView(Action):
)
sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
for r in rows:
await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
)
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=participant.subject
subject=entry.subject,
predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER,
object_=concat_namespace(datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3], add_affix(sequence_view)),
)
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(participant.subject)
)
await self.graph_db.save()

View file

@ -103,6 +103,7 @@ CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summary"
RESOURCES_FILE_REPO = "resources"
SD_OUTPUT_FILE_REPO = "resources/sd_output"
GRAPH_REPO_FILE_REPO = "docs/graph_repo"
VISUAL_GRAPH_REPO_FILE_REPO = "resources/graph_db"
CLASS_VIEW_FILE_REPO = "docs/class_view"
YAPI_URL = "http://yapi.deepwisdomai.com/"

View file

@ -79,7 +79,7 @@ class DotClassAttribute(BaseModel):
composition_val = pre_l + "Literal" + post_l # replace Literal[...] with Literal
type_ = pre_l + literal + post_l
else:
type_ = re.sub(r"['\"]", "", type_) # remove '"
type_ = re.sub(r"['\"]+", "", type_) # remove '"
composition_val = type_
if default_ == "None":
@ -95,7 +95,7 @@ class DotClassAttribute(BaseModel):
def parse_compositions(types_part) -> List[str]:
if not types_part:
return []
modified_string = re.sub(r"[\[\],]", "|", types_part)
modified_string = re.sub(r"[\[\],\(\)]", "|", types_part)
types = modified_string.split("|")
filters = {
"str",
@ -121,7 +121,7 @@ class DotClassAttribute(BaseModel):
}
result = set()
for t in types:
t = re.sub(r"['\"]", "", t.strip())
t = re.sub(r"['\"]+", "", t.strip())
if t and t not in filters:
result.add(t)
return list(result)

View file

@ -24,6 +24,7 @@ import traceback
import typing
from pathlib import Path
from typing import Any, List, Tuple, Union
from urllib.parse import quote, unquote
import aiofiles
import loguru
@ -411,6 +412,30 @@ def split_namespace(ns_class_name: str, maxsplit=1) -> List[str]:
return ns_class_name.split(":", maxsplit=maxsplit)
def auto_namespace(name: str) -> str:
    """Ensure ``name`` carries a namespace part, using ``?`` as a placeholder.

    Args:
        name: A possibly namespaced name.

    Returns:
        ``"?:?"`` when ``name`` is empty; ``name`` unchanged when
        ``split_namespace`` yields at least two parts; otherwise ``name``
        prefixed with the placeholder namespace ``"?:"``.
    """
    if not name:
        return "?:?"
    has_namespace = len(split_namespace(name)) >= 2
    if has_namespace:
        return name
    return f"?:{name}"
def add_affix(text, affix="brace"):
    """Wrap ``text`` according to the requested affix scheme.

    Args:
        text: The string to wrap.
        affix: ``"brace"`` surrounds the text with ``{`` and ``}``;
            ``"url"`` does the same and then percent-encodes the result with
            ``urllib.parse.quote``. Any other value leaves the text untouched.

    Returns:
        The wrapped (and, for ``"url"``, percent-encoded) text.
    """
    if affix == "brace":
        return "{" + text + "}"
    if affix == "url":
        return quote("{" + text + "}")
    # Unknown schemes are a no-op rather than an error.
    return text
def remove_affix(text, affix="brace"):
    """Strip the wrapper that the matching affix scheme would have added.

    Args:
        text: The wrapped string.
        affix: ``"brace"`` drops the first and last character; ``"url"``
            percent-decodes with ``urllib.parse.unquote`` first and then drops
            the first and last character. Any other value leaves the text
            untouched.

    Returns:
        The unwrapped text. The outer characters are removed unconditionally;
        no check is made that they are actually braces.
    """
    if affix == "brace":
        return text[1:-1]
    if affix == "url":
        decoded = unquote(text)
        return decoded[1:-1]
    # Unknown schemes are a no-op rather than an error.
    return text
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
"""
Generates a logging function to be used after a call is retried.

View file

@ -82,3 +82,7 @@ class DiGraphRepository(GraphRepository):
def pathname(self) -> Path:
p = Path(self.root) / self.name
return p.with_suffix(".json")
@property
def repo(self):
    """Expose the object stored in ``self._repo`` (read-only accessor)."""
    backing_store = self._repo
    return backing_store

View file

@ -37,6 +37,7 @@ class GraphKeyword:
HAS_PAGE_INFO = "has_page_info"
HAS_CLASS_VIEW = "has_class_view"
HAS_SEQUENCE_VIEW = "has_sequence_view"
HAS_SEQUENCE_VIEW_VER = "has_sequence_view_ver"
HAS_CLASS_USE_CASE = "has_class_use_case"
IS_COMPOSITE_OF = "is_composite_of"
IS_AGGREGATE_OF = "is_aggregate_of"
@ -216,7 +217,7 @@ class GraphRepository(ABC):
classes = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
mapping = defaultdict(list)
for c in classes:
_, name = split_namespace(c.subject)
name = split_namespace(c.subject)[-1]
mapping[name].append(c.subject)
rows = await graph_db.select(predicate=GraphKeyword.IS_COMPOSITE_OF)

View file

@ -33,6 +33,7 @@ from metagpt.const import (
TASK_PDF_FILE_REPO,
TEST_CODES_FILE_REPO,
TEST_OUTPUTS_FILE_REPO,
VISUAL_GRAPH_REPO_FILE_REPO,
)
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import GitRepository
@ -69,6 +70,7 @@ class ResourceFileRepositories(FileRepository):
code_summary: FileRepository
sd_output: FileRepository
code_plan_and_change: FileRepository
graph_repo: FileRepository
def __init__(self, git_repo):
super().__init__(git_repo=git_repo, relative_path=RESOURCES_FILE_REPO)
@ -82,6 +84,7 @@ class ResourceFileRepositories(FileRepository):
self.code_summary = git_repo.new_file_repository(relative_path=CODE_SUMMARIES_PDF_FILE_REPO)
self.sd_output = git_repo.new_file_repository(relative_path=SD_OUTPUT_FILE_REPO)
self.code_plan_and_change = git_repo.new_file_repository(relative_path=CODE_PLAN_AND_CHANGE_PDF_FILE_REPO)
self.graph_repo = git_repo.new_file_repository(relative_path=VISUAL_GRAPH_REPO_FILE_REPO)
class ProjectRepo(FileRepository):

View file

@ -0,0 +1,110 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19
@Author : mashenquan
@File : visual_graph_repo.py
@Desc : Visualize the graph.
"""
from __future__ import annotations
import re
from abc import ABC
from pathlib import Path
from typing import List, Optional
from pydantic import BaseModel, Field
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
from metagpt.schema import UMLClassView
from metagpt.utils.common import split_namespace
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class _VisualClassView(BaseModel):
    """A class plus its relationships, renderable as Mermaid class-diagram text."""

    # Namespaced class identifier; `name` is derived from it via split_namespace.
    package: str
    # Parsed UML description of the class; rendering yields "" while this is unset.
    uml: Optional[UMLClassView] = None
    # Names of related classes, one list per relationship kind.
    generalizations: List[str] = Field(default_factory=list)
    compositions: List[str] = Field(default_factory=list)
    aggregations: List[str] = Field(default_factory=list)

    @property
    def name(self) -> str:
        """Return the last element produced by ``split_namespace(self.package)``."""
        parts = split_namespace(self.package)
        return parts[-1]

    def get_mermaid(self, align: int = 1):
        """Render the class and its relationship lines as Mermaid text.

        Args:
            align: Number of tab characters used to indent each relationship
                line; also forwarded to ``UMLClassView.get_mermaid``.

        Returns:
            The Mermaid text, or an empty string when no UML view is attached.
        """
        if not self.uml:
            return ""
        indent = "\t" * align
        chunks = [self.uml.get_mermaid(align=align)]
        # Relationship lines follow the class body: generalizations first,
        # then compositions, then aggregations.
        chunks.extend(f"{indent}{other} <|-- {self.name}\n" for other in self.generalizations)
        chunks.extend(f"{indent}{other} *-- {self.name}\n" for other in self.compositions)
        chunks.extend(f"{indent}{other} o-- {self.name}\n" for other in self.aggregations)
        return "".join(chunks)
class VisualGraphRepo(ABC):
    """Base class for visualizers that hold a graph repository."""

    # The graph repository whose contents are visualized.
    graph_db: GraphRepository

    def __init__(self, graph_db):
        """Store ``graph_db`` on the instance for later queries."""
        self.graph_db = graph_db
class VisualDiGraphRepo(VisualGraphRepo):
    """Render the contents of a ``DiGraphRepository`` as Mermaid diagrams."""

    # Built-in type names that are dropped instead of being drawn as related classes.
    _BUILTIN_TYPE_NAMES = frozenset({"int", "float", "bool", "str", "list", "tuple", "set", "dict", "None"})

    @classmethod
    async def load_from(cls, filename: str | Path):
        """Create an instance backed by the graph loaded from ``filename``.

        Args:
            filename: Path of the graph file; converted to ``str`` before being
                handed to ``DiGraphRepository.load_from``.

        Returns:
            A new instance of ``cls`` wrapping the loaded graph repository.
        """
        graph_db = await DiGraphRepository.load_from(str(filename))
        return cls(graph_db=graph_db)

    async def get_mermaid_class_view(self) -> str:
        """Return one Mermaid ``classDiagram`` covering every class in the graph.

        Classes are the subjects of rows matching
        ``predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS``.
        """
        rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
        chunks = ["classDiagram\n"]
        for r in rows:
            view = await self._get_class_view(ns_class_name=r.subject)
            chunks.append(view.get_mermaid())
        return "".join(chunks)

    async def _get_class_view(self, ns_class_name: str) -> _VisualClassView:
        """Collect the UML view and relationships recorded for one class.

        Args:
            ns_class_name: Namespaced class name used as the graph subject.

        Returns:
            A ``_VisualClassView`` populated from every row whose subject is
            ``ns_class_name``; rows with other predicates are ignored.
        """
        rows = await self.graph_db.select(subject=ns_class_name)
        class_view = _VisualClassView(package=ns_class_name)
        # Each relationship predicate feeds exactly one list on the view.
        relation_targets = {
            GraphKeyword.IS + GENERALIZATION + GraphKeyword.OF: class_view.generalizations,
            GraphKeyword.IS + COMPOSITION + GraphKeyword.OF: class_view.compositions,
            GraphKeyword.IS + AGGREGATION + GraphKeyword.OF: class_view.aggregations,
        }
        for r in rows:
            if r.predicate == GraphKeyword.HAS_CLASS_VIEW:
                class_view.uml = UMLClassView.model_validate_json(r.object_)
                continue
            target = relation_targets.get(r.predicate)
            if target is None:
                continue
            name = self._refine_name(split_namespace(r.object_)[-1])
            if name:  # built-in type names are refined to "" and skipped
                target.append(name)
        return class_view

    async def get_mermaid_sequence_views(self) -> List[tuple[str, str]]:
        """Return ``(subject, sequence_view)`` pairs for every stored sequence view.

        Pairs come from rows whose predicate is ``GraphKeyword.HAS_SEQUENCE_VIEW``;
        the second item is the row's ``object_`` value, returned as stored.
        """
        rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
        return [(r.subject, r.object_) for r in rows]

    @staticmethod
    def _refine_name(name) -> str:
        """Normalize a related-class name for use in a Mermaid diagram.

        Leading and trailing quotes, backslashes and parentheses are stripped.
        Built-in type names yield ``""``. For any other dotted name only the
        last component is kept. The built-in check runs before the dotted
        split, so ``"pkg.int"`` is returned as ``"int"``.
        """
        name = re.sub(r'^[\'"\\\(\)]+|[\'"\\\(\)]+$', "", name)
        if name in VisualDiGraphRepo._BUILTIN_TYPE_NAMES:
            return ""
        if "." in name:
            name = name.split(".")[-1]
        return name

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -20,21 +20,21 @@ from metagpt.utils.graph_repository import SPO
@pytest.mark.asyncio
async def test_rebuild(context, mocker):
# Mock
data = await aread(filename=Path(__file__).parent / "../../data/graph_db/networkx.json")
data = await aread(filename=Path(__file__).parent / "../../data/graph_db/networkx.class_view.json")
graph_db_filename = Path(context.repo.workdir.name).with_suffix(".json")
await context.repo.docs.graph_repo.save(filename=str(graph_db_filename), content=data)
context.git_repo.add_change({f"{GRAPH_REPO_FILE_REPO}/{graph_db_filename}": ChangeType.UNTRACTED})
context.git_repo.commit("commit1")
# mock_spo = SPO(
# subject="metagpt/startup.py:__name__:__main__",
# predicate="has_page_info",
# object_='{"lineno":78,"end_lineno":79,"type_name":"ast.If","tokens":["__name__","__main__"],"properties":{}}',
# )
mock_spo = SPO(
subject="metagpt/tools/search_engine_serpapi.py:__name__:__main__",
subject="metagpt/startup.py:__name__:__main__",
predicate="has_page_info",
object_='{"lineno":113,"end_lineno":116,"type_name":"ast.If","tokens":["__name__","__main__"],"properties":{}}',
object_='{"lineno":78,"end_lineno":79,"type_name":"ast.If","tokens":["__name__","__main__"],"properties":{}}',
)
# mock_spo = SPO(
# subject="metagpt/tools/search_engine_serpapi.py:__name__:__main__",
# predicate="has_page_info",
# object_='{"lineno":113,"end_lineno":116,"type_name":"ast.If","tokens":["__name__","__main__"],"properties":{}}',
# )
mocker.patch.object(RebuildSequenceView, "_search_main_entry", return_value=[mock_spo])
action = RebuildSequenceView(

View file

@ -0,0 +1,26 @@
import re
from pathlib import Path
import pytest
from metagpt.utils.visual_graph_repo import VisualDiGraphRepo
@pytest.mark.asyncio
async def test_visual_di_graph_repo(context, mocker):
    """Load a stored graph, render its Mermaid views, and save them as markdown."""
    graph_file = Path(__file__).parent / "../../data/graph_db/networkx.sequence_view.json"
    repo = await VisualDiGraphRepo.load_from(filename=graph_file)

    # The class diagram must be non-empty; it is saved wrapped in a mermaid code fence.
    class_view = await repo.get_mermaid_class_view()
    assert class_view
    await context.repo.resources.graph_repo.save(filename="class_view.md", content=f"```mermaid\n{class_view}\n```\n")

    # One markdown file per sequence view, named after the sanitized subject.
    sequence_views = await repo.get_mermaid_sequence_views()
    assert sequence_views
    for subject, diagram in sequence_views:
        md_name = re.sub(r"[:/\\\.]+", "_", subject) + ".sequence_view.md"
        await context.repo.resources.graph_repo.save(filename=md_name, content=f"```mermaid\n{diagram}\n```\n")
if __name__ == "__main__":
    # Allow running this test module directly; "-s" turns off pytest's output capture.
    pytest_args = [__file__, "-s"]
    pytest.main(pytest_args)