mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
feat: +visual graph repo
This commit is contained in:
parent
c1552d7319
commit
fb7518c12b
13 changed files with 217 additions and 18 deletions
|
|
@ -9,6 +9,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
|
|
@ -22,7 +23,9 @@ from metagpt.logs import logger
|
|||
from metagpt.repo_parser import CodeBlockInfo, DotClassInfo
|
||||
from metagpt.schema import UMLClassView
|
||||
from metagpt.utils.common import (
|
||||
add_affix,
|
||||
aread,
|
||||
auto_namespace,
|
||||
concat_namespace,
|
||||
general_after_log,
|
||||
list_files,
|
||||
|
|
@ -119,11 +122,22 @@ class RebuildSequenceView(Action):
|
|||
],
|
||||
)
|
||||
sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
|
||||
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
|
||||
for r in rows:
|
||||
await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
|
||||
await self.graph_db.insert(
|
||||
subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
|
||||
)
|
||||
await self.graph_db.insert(
|
||||
subject=entry.subject,
|
||||
predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER,
|
||||
object_=concat_namespace(datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3], add_affix(sequence_view)),
|
||||
)
|
||||
for c in classes:
|
||||
await self.graph_db.insert(subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=c.subject)
|
||||
await self.graph_db.insert(
|
||||
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(c.subject)
|
||||
)
|
||||
await self.graph_db.save()
|
||||
|
||||
async def _merge_sequence_view(self, entry) -> bool:
|
||||
new_participant = await self._search_new_participant(entry)
|
||||
|
|
@ -198,6 +212,7 @@ class RebuildSequenceView(Action):
|
|||
await self.graph_db.insert(
|
||||
subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE, object_=detail.model_dump_json()
|
||||
)
|
||||
await self.graph_db.save()
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=20),
|
||||
|
|
@ -211,6 +226,7 @@ class RebuildSequenceView(Action):
|
|||
use_case_markdown = await self._get_class_use_cases(ns_class_name)
|
||||
if not use_case_markdown: # external class
|
||||
await self.graph_db.insert(subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_="")
|
||||
await self.graph_db.save()
|
||||
return
|
||||
block = f"## Use Cases\n{use_case_markdown}"
|
||||
prompts_blocks.append(block)
|
||||
|
|
@ -244,6 +260,7 @@ class RebuildSequenceView(Action):
|
|||
await self.graph_db.insert(
|
||||
subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
|
||||
)
|
||||
await self.graph_db.save()
|
||||
|
||||
async def _get_participants(self, ns_class_name) -> List[str]:
|
||||
participants = set()
|
||||
|
|
@ -319,7 +336,7 @@ class RebuildSequenceView(Action):
|
|||
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT)
|
||||
merged_participants = []
|
||||
for r in rows:
|
||||
_, name = split_namespace(r.object_)
|
||||
name = split_namespace(r.object_)[-1]
|
||||
merged_participants.append(name)
|
||||
participants = self.parse_participant(sequence_view)
|
||||
for p in participants:
|
||||
|
|
@ -337,19 +354,21 @@ class RebuildSequenceView(Action):
|
|||
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
|
||||
participants = []
|
||||
for r in rows:
|
||||
_, name = split_namespace(r.subject)
|
||||
name = split_namespace(r.subject)[-1]
|
||||
if name == class_name:
|
||||
participants.append(r)
|
||||
if len(participants) == 0: # external participants
|
||||
await self.graph_db.insert(
|
||||
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=concat_namespace("?", class_name)
|
||||
)
|
||||
await self.graph_db.save()
|
||||
return
|
||||
if len(participants) > 1:
|
||||
for r in participants:
|
||||
await self.graph_db.insert(
|
||||
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=r.object_
|
||||
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(r.subject)
|
||||
)
|
||||
await self.graph_db.save()
|
||||
return
|
||||
|
||||
participant = participants[0]
|
||||
|
|
@ -372,9 +391,18 @@ class RebuildSequenceView(Action):
|
|||
)
|
||||
|
||||
sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
|
||||
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
|
||||
for r in rows:
|
||||
await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
|
||||
await self.graph_db.insert(
|
||||
subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
|
||||
)
|
||||
await self.graph_db.insert(
|
||||
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=participant.subject
|
||||
subject=entry.subject,
|
||||
predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER,
|
||||
object_=concat_namespace(datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3], add_affix(sequence_view)),
|
||||
)
|
||||
await self.graph_db.insert(
|
||||
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(participant.subject)
|
||||
)
|
||||
await self.graph_db.save()
|
||||
|
|
|
|||
|
|
@ -103,6 +103,7 @@ CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summary"
|
|||
RESOURCES_FILE_REPO = "resources"
|
||||
SD_OUTPUT_FILE_REPO = "resources/sd_output"
|
||||
GRAPH_REPO_FILE_REPO = "docs/graph_repo"
|
||||
VISUAL_GRAPH_REPO_FILE_REPO = "resources/graph_db"
|
||||
CLASS_VIEW_FILE_REPO = "docs/class_view"
|
||||
|
||||
YAPI_URL = "http://yapi.deepwisdomai.com/"
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ class DotClassAttribute(BaseModel):
|
|||
composition_val = pre_l + "Literal" + post_l # replace Literal[...] with Literal
|
||||
type_ = pre_l + literal + post_l
|
||||
else:
|
||||
type_ = re.sub(r"['\"]", "", type_) # remove '"
|
||||
type_ = re.sub(r"['\"]+", "", type_) # remove '"
|
||||
composition_val = type_
|
||||
|
||||
if default_ == "None":
|
||||
|
|
@ -95,7 +95,7 @@ class DotClassAttribute(BaseModel):
|
|||
def parse_compositions(types_part) -> List[str]:
|
||||
if not types_part:
|
||||
return []
|
||||
modified_string = re.sub(r"[\[\],]", "|", types_part)
|
||||
modified_string = re.sub(r"[\[\],\(\)]", "|", types_part)
|
||||
types = modified_string.split("|")
|
||||
filters = {
|
||||
"str",
|
||||
|
|
@ -121,7 +121,7 @@ class DotClassAttribute(BaseModel):
|
|||
}
|
||||
result = set()
|
||||
for t in types:
|
||||
t = re.sub(r"['\"]", "", t.strip())
|
||||
t = re.sub(r"['\"]+", "", t.strip())
|
||||
if t and t not in filters:
|
||||
result.add(t)
|
||||
return list(result)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import traceback
|
|||
import typing
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Tuple, Union
|
||||
from urllib.parse import quote, unquote
|
||||
|
||||
import aiofiles
|
||||
import loguru
|
||||
|
|
@ -411,6 +412,30 @@ def split_namespace(ns_class_name: str, maxsplit=1) -> List[str]:
|
|||
return ns_class_name.split(":", maxsplit=maxsplit)
|
||||
|
||||
|
||||
def auto_namespace(name: str) -> str:
|
||||
if not name:
|
||||
return "?:?"
|
||||
v = split_namespace(name)
|
||||
if len(v) < 2:
|
||||
return f"?:{name}"
|
||||
return name
|
||||
|
||||
|
||||
def add_affix(text, affix="brace"):
|
||||
mappings = {
|
||||
"brace": lambda x: "{" + x + "}",
|
||||
"url": lambda x: quote("{" + x + "}"),
|
||||
}
|
||||
encoder = mappings.get(affix, lambda x: x)
|
||||
return encoder(text)
|
||||
|
||||
|
||||
def remove_affix(text, affix="brace"):
|
||||
mappings = {"brace": lambda x: x[1:-1], "url": lambda x: unquote(x)[1:-1]}
|
||||
decoder = mappings.get(affix, lambda x: x)
|
||||
return decoder(text)
|
||||
|
||||
|
||||
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
|
||||
"""
|
||||
Generates a logging function to be used after a call is retried.
|
||||
|
|
|
|||
|
|
@ -82,3 +82,7 @@ class DiGraphRepository(GraphRepository):
|
|||
def pathname(self) -> Path:
|
||||
p = Path(self.root) / self.name
|
||||
return p.with_suffix(".json")
|
||||
|
||||
@property
|
||||
def repo(self):
|
||||
return self._repo
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ class GraphKeyword:
|
|||
HAS_PAGE_INFO = "has_page_info"
|
||||
HAS_CLASS_VIEW = "has_class_view"
|
||||
HAS_SEQUENCE_VIEW = "has_sequence_view"
|
||||
HAS_SEQUENCE_VIEW_VER = "has_sequence_view_ver"
|
||||
HAS_CLASS_USE_CASE = "has_class_use_case"
|
||||
IS_COMPOSITE_OF = "is_composite_of"
|
||||
IS_AGGREGATE_OF = "is_aggregate_of"
|
||||
|
|
@ -216,7 +217,7 @@ class GraphRepository(ABC):
|
|||
classes = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
|
||||
mapping = defaultdict(list)
|
||||
for c in classes:
|
||||
_, name = split_namespace(c.subject)
|
||||
name = split_namespace(c.subject)[-1]
|
||||
mapping[name].append(c.subject)
|
||||
|
||||
rows = await graph_db.select(predicate=GraphKeyword.IS_COMPOSITE_OF)
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ from metagpt.const import (
|
|||
TASK_PDF_FILE_REPO,
|
||||
TEST_CODES_FILE_REPO,
|
||||
TEST_OUTPUTS_FILE_REPO,
|
||||
VISUAL_GRAPH_REPO_FILE_REPO,
|
||||
)
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
|
@ -69,6 +70,7 @@ class ResourceFileRepositories(FileRepository):
|
|||
code_summary: FileRepository
|
||||
sd_output: FileRepository
|
||||
code_plan_and_change: FileRepository
|
||||
graph_repo: FileRepository
|
||||
|
||||
def __init__(self, git_repo):
|
||||
super().__init__(git_repo=git_repo, relative_path=RESOURCES_FILE_REPO)
|
||||
|
|
@ -82,6 +84,7 @@ class ResourceFileRepositories(FileRepository):
|
|||
self.code_summary = git_repo.new_file_repository(relative_path=CODE_SUMMARIES_PDF_FILE_REPO)
|
||||
self.sd_output = git_repo.new_file_repository(relative_path=SD_OUTPUT_FILE_REPO)
|
||||
self.code_plan_and_change = git_repo.new_file_repository(relative_path=CODE_PLAN_AND_CHANGE_PDF_FILE_REPO)
|
||||
self.graph_repo = git_repo.new_file_repository(relative_path=VISUAL_GRAPH_REPO_FILE_REPO)
|
||||
|
||||
|
||||
class ProjectRepo(FileRepository):
|
||||
|
|
|
|||
110
metagpt/utils/visual_graph_repo.py
Normal file
110
metagpt/utils/visual_graph_repo.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/19
|
||||
@Author : mashenquan
|
||||
@File : visualize_graph.py
|
||||
@Desc : Visualize the graph.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
|
||||
from metagpt.schema import UMLClassView
|
||||
from metagpt.utils.common import split_namespace
|
||||
from metagpt.utils.di_graph_repository import DiGraphRepository
|
||||
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
|
||||
|
||||
|
||||
class _VisualClassView(BaseModel):
|
||||
package: str
|
||||
uml: Optional[UMLClassView] = None
|
||||
generalizations: List[str] = Field(default_factory=list)
|
||||
compositions: List[str] = Field(default_factory=list)
|
||||
aggregations: List[str] = Field(default_factory=list)
|
||||
|
||||
def get_mermaid(self, align: int = 1):
|
||||
if not self.uml:
|
||||
return ""
|
||||
prefix = "\t" * align
|
||||
|
||||
mermaid_txt = self.uml.get_mermaid(align=align)
|
||||
for i in self.generalizations:
|
||||
mermaid_txt += f"{prefix}{i} <|-- {self.name}\n"
|
||||
for i in self.compositions:
|
||||
mermaid_txt += f"{prefix}{i} *-- {self.name}\n"
|
||||
for i in self.aggregations:
|
||||
mermaid_txt += f"{prefix}{i} o-- {self.name}\n"
|
||||
return mermaid_txt
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return split_namespace(self.package)[-1]
|
||||
|
||||
|
||||
class VisualGraphRepo(ABC):
|
||||
graph_db: GraphRepository
|
||||
|
||||
def __init__(self, graph_db):
|
||||
self.graph_db = graph_db
|
||||
|
||||
|
||||
class VisualDiGraphRepo(VisualGraphRepo):
|
||||
@classmethod
|
||||
async def load_from(cls, filename: str | Path):
|
||||
graph_db = await DiGraphRepository.load_from(str(filename))
|
||||
return cls(graph_db=graph_db)
|
||||
|
||||
async def get_mermaid_class_view(self) -> str:
|
||||
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
|
||||
mermaid_txt = "classDiagram\n"
|
||||
for r in rows:
|
||||
v = await self._get_class_view(ns_class_name=r.subject)
|
||||
mermaid_txt += v.get_mermaid()
|
||||
return mermaid_txt
|
||||
|
||||
async def _get_class_view(self, ns_class_name: str) -> _VisualClassView:
|
||||
rows = await self.graph_db.select(subject=ns_class_name)
|
||||
class_view = _VisualClassView(package=ns_class_name)
|
||||
for r in rows:
|
||||
if r.predicate == GraphKeyword.HAS_CLASS_VIEW:
|
||||
class_view.uml = UMLClassView.model_validate_json(r.object_)
|
||||
elif r.predicate == GraphKeyword.IS + GENERALIZATION + GraphKeyword.OF:
|
||||
name = split_namespace(r.object_)[-1]
|
||||
name = self._refine_name(name)
|
||||
if name:
|
||||
class_view.generalizations.append(name)
|
||||
elif r.predicate == GraphKeyword.IS + COMPOSITION + GraphKeyword.OF:
|
||||
name = split_namespace(r.object_)[-1]
|
||||
name = self._refine_name(name)
|
||||
if name:
|
||||
class_view.compositions.append(name)
|
||||
elif r.predicate == GraphKeyword.IS + AGGREGATION + GraphKeyword.OF:
|
||||
name = split_namespace(r.object_)[-1]
|
||||
name = self._refine_name(name)
|
||||
if name:
|
||||
class_view.aggregations.append(name)
|
||||
return class_view
|
||||
|
||||
async def get_mermaid_sequence_views(self) -> List[(str, str)]:
|
||||
sequence_views = []
|
||||
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
|
||||
for r in rows:
|
||||
sequence_views.append((r.subject, r.object_))
|
||||
return sequence_views
|
||||
|
||||
@staticmethod
|
||||
def _refine_name(name) -> str:
|
||||
name = re.sub(r'^[\'"\\\(\)]+|[\'"\\\(\)]+$', "", name)
|
||||
if name in ["int", "float", "bool", "str", "list", "tuple", "set", "dict", "None"]:
|
||||
return ""
|
||||
if "." in name:
|
||||
name = name.split(".")[-1]
|
||||
|
||||
return name
|
||||
1
tests/data/graph_db/networkx.class_view.json
Normal file
1
tests/data/graph_db/networkx.class_view.json
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1
tests/data/graph_db/networkx.sequence_view.json
Normal file
1
tests/data/graph_db/networkx.sequence_view.json
Normal file
File diff suppressed because one or more lines are too long
|
|
@ -20,21 +20,21 @@ from metagpt.utils.graph_repository import SPO
|
|||
@pytest.mark.asyncio
|
||||
async def test_rebuild(context, mocker):
|
||||
# Mock
|
||||
data = await aread(filename=Path(__file__).parent / "../../data/graph_db/networkx.json")
|
||||
data = await aread(filename=Path(__file__).parent / "../../data/graph_db/networkx.class_view.json")
|
||||
graph_db_filename = Path(context.repo.workdir.name).with_suffix(".json")
|
||||
await context.repo.docs.graph_repo.save(filename=str(graph_db_filename), content=data)
|
||||
context.git_repo.add_change({f"{GRAPH_REPO_FILE_REPO}/{graph_db_filename}": ChangeType.UNTRACTED})
|
||||
context.git_repo.commit("commit1")
|
||||
# mock_spo = SPO(
|
||||
# subject="metagpt/startup.py:__name__:__main__",
|
||||
# predicate="has_page_info",
|
||||
# object_='{"lineno":78,"end_lineno":79,"type_name":"ast.If","tokens":["__name__","__main__"],"properties":{}}',
|
||||
# )
|
||||
mock_spo = SPO(
|
||||
subject="metagpt/tools/search_engine_serpapi.py:__name__:__main__",
|
||||
subject="metagpt/startup.py:__name__:__main__",
|
||||
predicate="has_page_info",
|
||||
object_='{"lineno":113,"end_lineno":116,"type_name":"ast.If","tokens":["__name__","__main__"],"properties":{}}',
|
||||
object_='{"lineno":78,"end_lineno":79,"type_name":"ast.If","tokens":["__name__","__main__"],"properties":{}}',
|
||||
)
|
||||
# mock_spo = SPO(
|
||||
# subject="metagpt/tools/search_engine_serpapi.py:__name__:__main__",
|
||||
# predicate="has_page_info",
|
||||
# object_='{"lineno":113,"end_lineno":116,"type_name":"ast.If","tokens":["__name__","__main__"],"properties":{}}',
|
||||
# )
|
||||
mocker.patch.object(RebuildSequenceView, "_search_main_entry", return_value=[mock_spo])
|
||||
|
||||
action = RebuildSequenceView(
|
||||
|
|
|
|||
26
tests/metagpt/utils/test_visual_graph_repo.py
Normal file
26
tests/metagpt/utils/test_visual_graph_repo.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.utils.visual_graph_repo import VisualDiGraphRepo
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_visual_di_graph_repo(context, mocker):
|
||||
filename = Path(__file__).parent / "../../data/graph_db/networkx.sequence_view.json"
|
||||
repo = await VisualDiGraphRepo.load_from(filename=filename)
|
||||
|
||||
class_view = await repo.get_mermaid_class_view()
|
||||
assert class_view
|
||||
await context.repo.resources.graph_repo.save(filename="class_view.md", content=f"```mermaid\n{class_view}\n```\n")
|
||||
|
||||
sequence_views = await repo.get_mermaid_sequence_views()
|
||||
assert sequence_views
|
||||
for ns, sqv in sequence_views:
|
||||
filename = re.sub(r"[:/\\\.]+", "_", ns) + ".sequence_view.md"
|
||||
await context.repo.resources.graph_repo.save(filename=filename, content=f"```mermaid\n{sqv}\n```\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
Loading…
Add table
Add a link
Reference in a new issue