feat: +visual graph repo

This commit is contained in:
莘权 马 2024-01-31 23:40:04 +08:00
parent c1552d7319
commit fb7518c12b
13 changed files with 217 additions and 18 deletions

View file

@ -24,6 +24,7 @@ import traceback
import typing
from pathlib import Path
from typing import Any, List, Tuple, Union
from urllib.parse import quote, unquote
import aiofiles
import loguru
@ -411,6 +412,30 @@ def split_namespace(ns_class_name: str, maxsplit=1) -> List[str]:
return ns_class_name.split(":", maxsplit=maxsplit)
def auto_namespace(name: str) -> str:
if not name:
return "?:?"
v = split_namespace(name)
if len(v) < 2:
return f"?:{name}"
return name
def add_affix(text, affix="brace"):
mappings = {
"brace": lambda x: "{" + x + "}",
"url": lambda x: quote("{" + x + "}"),
}
encoder = mappings.get(affix, lambda x: x)
return encoder(text)
def remove_affix(text, affix="brace"):
mappings = {"brace": lambda x: x[1:-1], "url": lambda x: unquote(x)[1:-1]}
decoder = mappings.get(affix, lambda x: x)
return decoder(text)
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
"""
Generates a logging function to be used after a call is retried.

View file

@ -82,3 +82,7 @@ class DiGraphRepository(GraphRepository):
def pathname(self) -> Path:
p = Path(self.root) / self.name
return p.with_suffix(".json")
@property
def repo(self):
return self._repo

View file

@ -37,6 +37,7 @@ class GraphKeyword:
HAS_PAGE_INFO = "has_page_info"
HAS_CLASS_VIEW = "has_class_view"
HAS_SEQUENCE_VIEW = "has_sequence_view"
HAS_SEQUENCE_VIEW_VER = "has_sequence_view_ver"
HAS_CLASS_USE_CASE = "has_class_use_case"
IS_COMPOSITE_OF = "is_composite_of"
IS_AGGREGATE_OF = "is_aggregate_of"
@ -216,7 +217,7 @@ class GraphRepository(ABC):
classes = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
mapping = defaultdict(list)
for c in classes:
_, name = split_namespace(c.subject)
name = split_namespace(c.subject)[-1]
mapping[name].append(c.subject)
rows = await graph_db.select(predicate=GraphKeyword.IS_COMPOSITE_OF)

View file

@ -33,6 +33,7 @@ from metagpt.const import (
TASK_PDF_FILE_REPO,
TEST_CODES_FILE_REPO,
TEST_OUTPUTS_FILE_REPO,
VISUAL_GRAPH_REPO_FILE_REPO,
)
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import GitRepository
@ -69,6 +70,7 @@ class ResourceFileRepositories(FileRepository):
code_summary: FileRepository
sd_output: FileRepository
code_plan_and_change: FileRepository
graph_repo: FileRepository
def __init__(self, git_repo):
super().__init__(git_repo=git_repo, relative_path=RESOURCES_FILE_REPO)
@ -82,6 +84,7 @@ class ResourceFileRepositories(FileRepository):
self.code_summary = git_repo.new_file_repository(relative_path=CODE_SUMMARIES_PDF_FILE_REPO)
self.sd_output = git_repo.new_file_repository(relative_path=SD_OUTPUT_FILE_REPO)
self.code_plan_and_change = git_repo.new_file_repository(relative_path=CODE_PLAN_AND_CHANGE_PDF_FILE_REPO)
self.graph_repo = git_repo.new_file_repository(relative_path=VISUAL_GRAPH_REPO_FILE_REPO)
class ProjectRepo(FileRepository):

View file

@ -0,0 +1,110 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19
@Author : mashenquan
@File : visualize_graph.py
@Desc : Visualize the graph.
"""
from __future__ import annotations
import re
from abc import ABC
from pathlib import Path
from typing import List, Optional
from pydantic import BaseModel, Field
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
from metagpt.schema import UMLClassView
from metagpt.utils.common import split_namespace
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class _VisualClassView(BaseModel):
package: str
uml: Optional[UMLClassView] = None
generalizations: List[str] = Field(default_factory=list)
compositions: List[str] = Field(default_factory=list)
aggregations: List[str] = Field(default_factory=list)
def get_mermaid(self, align: int = 1):
if not self.uml:
return ""
prefix = "\t" * align
mermaid_txt = self.uml.get_mermaid(align=align)
for i in self.generalizations:
mermaid_txt += f"{prefix}{i} <|-- {self.name}\n"
for i in self.compositions:
mermaid_txt += f"{prefix}{i} *-- {self.name}\n"
for i in self.aggregations:
mermaid_txt += f"{prefix}{i} o-- {self.name}\n"
return mermaid_txt
@property
def name(self) -> str:
return split_namespace(self.package)[-1]
class VisualGraphRepo(ABC):
graph_db: GraphRepository
def __init__(self, graph_db):
self.graph_db = graph_db
class VisualDiGraphRepo(VisualGraphRepo):
@classmethod
async def load_from(cls, filename: str | Path):
graph_db = await DiGraphRepository.load_from(str(filename))
return cls(graph_db=graph_db)
async def get_mermaid_class_view(self) -> str:
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
mermaid_txt = "classDiagram\n"
for r in rows:
v = await self._get_class_view(ns_class_name=r.subject)
mermaid_txt += v.get_mermaid()
return mermaid_txt
async def _get_class_view(self, ns_class_name: str) -> _VisualClassView:
rows = await self.graph_db.select(subject=ns_class_name)
class_view = _VisualClassView(package=ns_class_name)
for r in rows:
if r.predicate == GraphKeyword.HAS_CLASS_VIEW:
class_view.uml = UMLClassView.model_validate_json(r.object_)
elif r.predicate == GraphKeyword.IS + GENERALIZATION + GraphKeyword.OF:
name = split_namespace(r.object_)[-1]
name = self._refine_name(name)
if name:
class_view.generalizations.append(name)
elif r.predicate == GraphKeyword.IS + COMPOSITION + GraphKeyword.OF:
name = split_namespace(r.object_)[-1]
name = self._refine_name(name)
if name:
class_view.compositions.append(name)
elif r.predicate == GraphKeyword.IS + AGGREGATION + GraphKeyword.OF:
name = split_namespace(r.object_)[-1]
name = self._refine_name(name)
if name:
class_view.aggregations.append(name)
return class_view
async def get_mermaid_sequence_views(self) -> List[(str, str)]:
sequence_views = []
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
for r in rows:
sequence_views.append((r.subject, r.object_))
return sequence_views
@staticmethod
def _refine_name(name) -> str:
name = re.sub(r'^[\'"\\\(\)]+|[\'"\\\(\)]+$', "", name)
if name in ["int", "float", "bool", "str", "list", "tuple", "set", "dict", "None"]:
return ""
if "." in name:
name = name.split(".")[-1]
return name