feat: +rebuild project

feat: parse pass
This commit is contained in:
莘权 马 2024-01-02 18:42:59 +08:00
parent aec2b72c8d
commit eabe0224c5
8 changed files with 214 additions and 60 deletions

1
.gitignore vendored
View file

@ -171,3 +171,4 @@ tests/metagpt/utils/file_repo_git
*.png
htmlcov
htmlcov.*
*.dot

View file

@ -9,52 +9,51 @@
import re
from pathlib import Path
import aiofiles
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.const import CLASS_VIEW_FILE_REPO, GRAPH_REPO_FILE_REPO
from metagpt.const import (
CLASS_VIEW_FILE_REPO,
DATA_API_DESIGN_FILE_REPO,
GRAPH_REPO_FILE_REPO,
)
from metagpt.repo_parser import RepoParser
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class RebuildClassView(Action):
def __init__(self, name="", context=None, llm=None):
super().__init__(name=name, context=context, llm=llm)
async def run(self, with_messages=None, format=CONFIG.prompt_schema):
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
repo_parser = RepoParser(base_directory=self.context)
class_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint
repo_parser = RepoParser(base_directory=Path(self.context))
class_views, relationship_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint
await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views)
symbols = repo_parser.generate_symbols() # use ast
for file_info in symbols:
await GraphRepository.update_graph_db_with_file_info(graph_db, file_info)
await self._create_mermaid_class_view(graph_db=graph_db)
# await graph_db.save(path=graph_repo_pathname.parent)
await self._create_mermaid_class_views(graph_db=graph_db)
await self._save(graph_db=graph_db)
async def _create_mermaid_class_view(self, graph_db):
pass
# dataset = await graph_db.select(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_PAGE_INFO)
# if not dataset:
# logger.warning(f"No page info for {concat_namespace(filename, class_name)}")
# return
# code_block_info = CodeBlockInfo.parse_raw(dataset[0].object_)
# src_code = await read_file_block(filename=Path(self.context) / filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno)
# code_type = ""
# dataset = await graph_db.select(subject=filename, predicate=GraphKeyword.IS)
# for spo in dataset:
# if spo.object_ in ["javascript", "python"]:
# code_type = spo.object_
# break
async def _create_mermaid_class_views(self, graph_db):
path = Path(CONFIG.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
path.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(str(path / CONFIG.git_repo.workdir.name), mode="w", encoding="utf-8") as writer:
await writer.write("classDiagram\n")
# class names
rows = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
distinct = {}
for r in rows:
await RebuildClassView._create_mermaid_class(r, graph_db, writer, distinct)
# try:
# node = await REBUILD_CLASS_VIEW_NODE.fill(context=f"```{code_type}\n{src_code}\n```", llm=self.llm, to=format)
# class_view = node.instruct_content.model_dump()["Class View"]
# except Exception as e:
# class_view = RepoParser.rebuild_class_view(src_code, code_type)
# await graph_db.insert(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_CLASS_VIEW, object_=class_view)
# logger.info(f"{concat_namespace(filename, class_name)} {GraphKeyword.HAS_CLASS_VIEW} {class_view}")
@staticmethod
async def _create_mermaid_class(ns_class_name, graph_db, file_writer, distinct):
pass
# fields = split_namespace(ns_class_name)
# await graph_db.select(subject=ns_class_name)
async def _save(self, graph_db):
class_view_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO)

View file

@ -126,3 +126,8 @@ LLM_API_TIMEOUT = 300
# Message id
IGNORED_MESSAGE_ID = "0"
# Class Relationship
GENERALIZATION = "Generalize"
COMPOSITION = "Composite"
AGGREGATION = "Aggregate"

View file

@ -13,15 +13,15 @@ import re
import subprocess
from pathlib import Path
from pprint import pformat
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional
import aiofiles
import pandas as pd
from pydantic import BaseModel, Field
from metagpt.config import CONFIG
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
from metagpt.logs import logger
from metagpt.utils.common import any_to_str
from metagpt.utils.common import any_to_str, aread
from metagpt.utils.exceptions import handle_exception
@ -48,6 +48,13 @@ class ClassInfo(BaseModel):
methods: Dict[str, str] = Field(default_factory=dict)
class ClassRelationship(BaseModel):
src: str = ""
dest: str = ""
relationship: str = ""
label: Optional[str] = None
class RepoParser(BaseModel):
base_directory: Path = Field(default=None)
@ -62,7 +69,8 @@ class RepoParser(BaseModel):
file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory)))
for node in tree:
info = RepoParser.node_to_str(node)
file_info.page_info.append(info)
if info:
file_info.page_info.append(info)
if isinstance(node, ast.ClassDef):
class_methods = [m.name for m in node.body if is_func(m)]
file_info.classes.append({"name": node.name, "methods": class_methods})
@ -111,7 +119,9 @@ class RepoParser(BaseModel):
self.generate_dataframe_structure(output_path)
@staticmethod
def node_to_str(node) -> (int, int, str, str | Tuple):
def node_to_str(node) -> CodeBlockInfo | None:
if isinstance(node, ast.Try):
return None
if any_to_str(node) == any_to_str(ast.Expr):
return CodeBlockInfo(
lineno=node.lineno,
@ -130,6 +140,7 @@ class RepoParser(BaseModel):
},
any_to_str(ast.If): RepoParser._parse_if,
any_to_str(ast.AsyncFunctionDef): lambda x: x.name,
any_to_str(ast.AnnAssign): lambda x: RepoParser._parse_variable(x.target),
}
func = mappings.get(any_to_str(node))
if func:
@ -165,22 +176,52 @@ class RepoParser(BaseModel):
@staticmethod
def _parse_if(n):
tokens = [RepoParser._parse_variable(n.test.left)]
for item in n.test.comparators:
tokens.append(RepoParser._parse_variable(item))
tokens = []
try:
if isinstance(n.test, ast.BoolOp):
tokens = []
for v in n.test.values:
tokens.extend(RepoParser._parse_if_compare(v))
return tokens
if isinstance(n.test, ast.Compare):
v = RepoParser._parse_variable(n.test.left)
if v:
tokens.append(v)
for item in n.test.comparators:
v = RepoParser._parse_variable(item)
if v:
tokens.append(v)
return tokens
except Exception as e:
logger.warning(e)
return tokens
@staticmethod
def _parse_if_compare(n):
if hasattr(n, "left"):
return RepoParser._parse_variable(n.left)
else:
return []
@staticmethod
def _parse_variable(node):
funcs = {
any_to_str(ast.Constant): lambda x: x.value,
any_to_str(ast.Name): lambda x: x.id,
any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}",
}
func = funcs.get(any_to_str(node))
if not func:
raise NotImplementedError(f"Not implement:{node}")
return func(node)
try:
funcs = {
any_to_str(ast.Constant): lambda x: x.value,
any_to_str(ast.Name): lambda x: x.id,
any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}"
if hasattr(x.value, "id")
else f"{x.attr}",
any_to_str(ast.Call): lambda x: RepoParser._parse_variable(x.func),
any_to_str(ast.Tuple): lambda x: "",
}
func = funcs.get(any_to_str(node))
if not func:
raise NotImplementedError(f"Not implement:{node}")
return func(node)
except Exception as e:
logger.warning(e)
raise e
@staticmethod
def _parse_assign(node):
@ -198,18 +239,21 @@ class RepoParser(BaseModel):
raise ValueError(f"{result}")
class_view_pathname = path / "classes.dot"
class_views = await self._parse_classes(class_view_pathname)
relationship_views = await self._parse_class_relationships(class_view_pathname)
packages_pathname = path / "packages.dot"
class_views = RepoParser._repair_namespaces(class_views=class_views, path=path)
class_views, relationship_views = RepoParser._repair_namespaces(
class_views=class_views, relationship_views=relationship_views, path=path
)
class_view_pathname.unlink(missing_ok=True)
packages_pathname.unlink(missing_ok=True)
return class_views
return class_views, relationship_views
async def _parse_classes(self, class_view_pathname):
class_views = []
if not class_view_pathname.exists():
return class_views
async with aiofiles.open(str(class_view_pathname), mode="r") as reader:
lines = await reader.readlines()
data = await aread(filename=class_view_pathname, encoding="utf-8")
lines = data.split("\n")
for line in lines:
package_name, info = RepoParser._split_class_line(line)
if not package_name:
@ -230,6 +274,19 @@ class RepoParser(BaseModel):
class_views.append(class_info)
return class_views
async def _parse_class_relationships(self, class_view_pathname) -> List[ClassRelationShip]:
relationship_views = []
if not class_view_pathname.exists():
return relationship_views
data = await aread(filename=class_view_pathname, encoding="utf-8")
lines = data.split("\n")
for line in lines:
relationship = RepoParser._split_relationship_line(line)
if not relationship:
continue
relationship_views.append(relationship)
return relationship_views
@staticmethod
def _split_class_line(line):
part_splitor = '" ['
@ -248,6 +305,40 @@ class RepoParser(BaseModel):
info = re.sub(r"<br[^>]*>", "\n", info)
return class_name, info
@staticmethod
def _split_relationship_line(line):
splitters = [" -> ", " [", "];"]
idxs = []
for tag in splitters:
if tag not in line:
return None
idxs.append(line.find(tag))
ret = ClassRelationship()
ret.src = line[0 : idxs[0]].strip('"')
ret.dest = line[idxs[0] + len(splitters[0]) : idxs[1]].strip('"')
properties = line[idxs[1] + len(splitters[1]) : idxs[2]].strip(" ")
mappings = {
'arrowhead="empty"': GENERALIZATION,
'arrowhead="diamond"': COMPOSITION,
'arrowhead="odiamond"': AGGREGATION,
}
for k, v in mappings.items():
if k in properties:
ret.relationship = v
if v != GENERALIZATION:
ret.label = RepoParser._get_label(properties)
break
return ret
@staticmethod
def _get_label(line):
tag = 'label="'
if tag not in line:
return ""
ix = line.find(tag)
eix = line.find('"', ix + len(tag))
return line[ix + len(tag) : eix]
@staticmethod
def _create_path_mapping(path: str | Path) -> Dict[str, str]:
mappings = {
@ -272,7 +363,9 @@ class RepoParser(BaseModel):
return mappings
@staticmethod
def _repair_namespaces(class_views: List[ClassInfo], path: str | Path) -> List[ClassInfo]:
def _repair_namespaces(
class_views: List[ClassInfo], relationship_views: List[ClassRelationship], path: str | Path
) -> (List[ClassInfo], List[ClassRelationShip]):
if not class_views:
return []
c = class_views[0]
@ -291,7 +384,12 @@ class RepoParser(BaseModel):
for c in class_views:
c.package = RepoParser._repair_ns(c.package, new_mappings)
return class_views
for i in range(len(relationship_views)):
v = relationship_views[i]
v.src = RepoParser._repair_ns(v.src, new_mappings)
v.dest = RepoParser._repair_ns(v.dest, new_mappings)
relationship_views[i] = v
return class_views, relationship_views
@staticmethod
def _repair_ns(package, mappings):

View file

@ -419,6 +419,10 @@ def concat_namespace(*args) -> str:
return ":".join(str(value) for value in args)
def split_namespace(ns_class_name: str) -> List[str]:
pass
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
"""
Generates a logging function to be used after a call is retried.

View file

@ -12,9 +12,9 @@ import json
from pathlib import Path
from typing import List
import aiofiles
import networkx
from metagpt.utils.common import aread, awrite
from metagpt.utils.graph_repository import SPO, GraphRepository
@ -55,12 +55,10 @@ class DiGraphRepository(GraphRepository):
if not path.exists():
path.mkdir(parents=True, exist_ok=True)
pathname = Path(path) / self.name
async with aiofiles.open(str(pathname.with_suffix(".json")), mode="w", encoding="utf-8") as writer:
await writer.write(data)
await awrite(filename=pathname.with_suffix(".json"), data=data, encoding="utf-8")
async def load(self, pathname: str | Path):
async with aiofiles.open(str(pathname), mode="r", encoding="utf-8") as reader:
data = await reader.read(-1)
data = await aread(filename=pathname, encoding="utf-8")
m = json.loads(data)
self._repo = networkx.node_link_graph(m)

View file

@ -13,19 +13,24 @@ from typing import List
from pydantic import BaseModel
from metagpt.repo_parser import ClassInfo, RepoFileInfo
from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo
from metagpt.utils.common import concat_namespace
class GraphKeyword:
IS = "is"
OF = "Of"
ON = "On"
CLASS = "class"
FUNCTION = "function"
HAS_FUNCTION = "has_function"
SOURCE_CODE = "source_code"
NULL = "<null>"
GLOBAL_VARIABLE = "global_variable"
CLASS_FUNCTION = "class_function"
CLASS_PROPERTY = "class_property"
HAS_CLASS_FUNCTION = "has_class_function"
HAS_CLASS_PROPERTY = "has_class_property"
HAS_CLASS = "has_class"
HAS_PAGE_INFO = "has_page_info"
HAS_CLASS_VIEW = "has_class_view"
@ -73,11 +78,13 @@ class GraphRepository(ABC):
await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=file_type)
for c in file_info.classes:
class_name = c.get("name", "")
# file -> class
await graph_db.insert(
subject=file_info.file,
predicate=GraphKeyword.HAS_CLASS,
object_=concat_namespace(file_info.file, class_name),
)
# class detail
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name),
predicate=GraphKeyword.IS,
@ -85,12 +92,22 @@ class GraphRepository(ABC):
)
methods = c.get("methods", [])
for fn in methods:
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name),
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
object_=concat_namespace(file_info.file, class_name, fn),
)
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
)
for f in file_info.functions:
# file -> function
await graph_db.insert(
subject=file_info.file, predicate=GraphKeyword.HAS_FUNCTION, object_=concat_namespace(file_info.file, f)
)
# function detail
await graph_db.insert(
subject=concat_namespace(file_info.file, f), predicate=GraphKeyword.IS, object_=GraphKeyword.FUNCTION
)
@ -105,13 +122,13 @@ class GraphRepository(ABC):
await graph_db.insert(
subject=concat_namespace(file_info.file, *code_block.tokens),
predicate=GraphKeyword.HAS_PAGE_INFO,
object_=code_block.json(ensure_ascii=False),
object_=code_block.model_dump_json(),
)
for k, v in code_block.properties.items():
await graph_db.insert(
subject=concat_namespace(file_info.file, k, v),
predicate=GraphKeyword.HAS_PAGE_INFO,
object_=code_block.json(ensure_ascii=False),
object_=code_block.model_dump_json(),
)
@staticmethod
@ -129,6 +146,13 @@ class GraphRepository(ABC):
object_=GraphKeyword.CLASS,
)
for vn, vt in c.attributes.items():
# class -> property
await graph_db.insert(
subject=c.package,
predicate=GraphKeyword.HAS_CLASS_PROPERTY,
object_=concat_namespace(c.package, vn),
)
# property detail
await graph_db.insert(
subject=concat_namespace(c.package, vn),
predicate=GraphKeyword.IS,
@ -138,6 +162,13 @@ class GraphRepository(ABC):
subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
)
for fn, desc in c.methods.items():
# class -> function
await graph_db.insert(
subject=c.package,
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
object_=concat_namespace(c.package, fn),
)
# function detail
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.IS,
@ -148,3 +179,19 @@ class GraphRepository(ABC):
predicate=GraphKeyword.HAS_ARGS_DESC,
object_=desc,
)
@staticmethod
async def update_graph_db_with_class_relationship_views(
graph_db: "GraphRepository", relationship_views: List[ClassRelationship]
):
for r in relationship_views:
await graph_db.insert(
subject=r.src, predicate=GraphKeyword.IS + r.relationship + GraphKeyword.OF, object_=r.dest
)
if not r.label:
continue
await graph_db.insert(
subject=r.src,
predicate=GraphKeyword.IS + r.relationship + GraphKeyword.ON,
object_=concat_namespace(r.dest, r.label),
)

View file

@ -16,7 +16,9 @@ from metagpt.llm import LLM
@pytest.mark.asyncio
async def test_rebuild():
action = RebuildClassView(name="RedBean", context=Path(__file__).parent.parent, llm=LLM())
action = RebuildClassView(
name="RedBean", context=str(Path(__file__).parent.parent.parent.parent / "metagpt"), llm=LLM()
)
await action.run()