diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py
index 6f1215920..4ed8bf22e 100644
--- a/metagpt/actions/action_node.py
+++ b/metagpt/actions/action_node.py
@@ -39,7 +39,7 @@ SIMPLE_TEMPLATE = """
{constraint}
## action
-Fill in the above nodes based on the format example.
+Based on the 'context' content, fill in the {node_name} using the 'format example' format above.
"""
@@ -247,8 +247,13 @@ class ActionNode:
# FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线",
self.instruction = self.compile_instruction(to="markdown", mode=mode)
self.example = self.compile_example(to=to, tag="CONTENT", mode=mode)
+ node_name = "nodes" if template != SIMPLE_TEMPLATE else f'"{list(self.children.keys())[0]}" node'
prompt = template.format(
- context=context, example=self.example, instruction=self.instruction, constraint=CONSTRAINT
+ context=context,
+ example=self.example,
+ instruction=self.instruction,
+ constraint=CONSTRAINT,
+ node_name=node_name,
)
return prompt
@@ -302,6 +307,7 @@ class ActionNode:
mapping = self.get_mapping(mode)
class_name = f"{self.key}_AN"
+ print(prompt)
output = await self._aask_v1(prompt, class_name, mapping, format=to)
self.content = output.content
self.instruct_content = output.instruct_content
diff --git a/metagpt/actions/rebuild_class_view.py b/metagpt/actions/rebuild_class_view.py
new file mode 100644
index 000000000..6da3e2989
--- /dev/null
+++ b/metagpt/actions/rebuild_class_view.py
@@ -0,0 +1,68 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2023/12/19
+@Author : mashenquan
+@File : rebuild_class_view.py
+@Desc : Rebuild class view info
+"""
+import re
+from pathlib import Path
+
+from metagpt.actions import Action
+from metagpt.config import CONFIG
+from metagpt.const import CLASS_VIEW_FILE_REPO, GRAPH_REPO_FILE_REPO
+from metagpt.repo_parser import RepoParser
+from metagpt.utils.di_graph_repository import DiGraphRepository
+from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
+
+
+class RebuildClassView(Action):
+ def __init__(self, name="", context=None, llm=None):
+ super().__init__(name=name, context=context, llm=llm)
+
+ async def run(self, with_messages=None, format=CONFIG.prompt_format):
+ graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
+ graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
+ repo_parser = RepoParser(base_directory=self.context)
+ class_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint
+ await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
+ symbols = repo_parser.generate_symbols() # use ast
+ for file_info in symbols:
+ await GraphRepository.update_graph_db_with_file_info(graph_db, file_info)
+ await self._create_mermaid_class_view(graph_db=graph_db)
+ await self._save(graph_db=graph_db)
+
+ async def _create_mermaid_class_view(self, graph_db):
+ pass
+ # dataset = await graph_db.select(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_PAGE_INFO)
+ # if not dataset:
+ # logger.warning(f"No page info for {concat_namespace(filename, class_name)}")
+ # return
+ # code_block_info = CodeBlockInfo.parse_raw(dataset[0].object_)
+ # src_code = await read_file_block(filename=Path(self.context) / filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno)
+ # code_type = ""
+ # dataset = await graph_db.select(subject=filename, predicate=GraphKeyword.IS)
+ # for spo in dataset:
+ # if spo.object_ in ["javascript", "python"]:
+ # code_type = spo.object_
+ # break
+
+ # try:
+ # node = await REBUILD_CLASS_VIEW_NODE.fill(context=f"```{code_type}\n{src_code}\n```", llm=self.llm, to=format)
+ # class_view = node.instruct_content.dict()["Class View"]
+ # except Exception as e:
+ # class_view = RepoParser.rebuild_class_view(src_code, code_type)
+ # await graph_db.insert(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_CLASS_VIEW, object_=class_view)
+ # logger.info(f"{concat_namespace(filename, class_name)} {GraphKeyword.HAS_CLASS_VIEW} {class_view}")
+
+ async def _save(self, graph_db):
+ class_view_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO)
+ dataset = await graph_db.select(predicate=GraphKeyword.HAS_CLASS_VIEW)
+ all_class_view = []
+ for spo in dataset:
+ title = f"---\ntitle: {spo.subject}\n---\n"
+ filename = re.sub(r"[/:]", "_", spo.subject) + ".mmd"
+ await class_view_file_repo.save(filename=filename, content=title + spo.object_)
+ all_class_view.append(spo.object_)
+ await class_view_file_repo.save(filename="all.mmd", content="\n".join(all_class_view))
diff --git a/metagpt/actions/rebuild_class_view_an.py b/metagpt/actions/rebuild_class_view_an.py
new file mode 100644
index 000000000..da32a9b5e
--- /dev/null
+++ b/metagpt/actions/rebuild_class_view_an.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2023/12/19
+@Author : mashenquan
+@File : rebuild_class_view_an.py
+@Desc : Defines `ActionNode` objects used by rebuild_class_view.py
+"""
+from metagpt.actions.action_node import ActionNode
+
+CLASS_SOURCE_CODE_BLOCK = ActionNode(
+ key="Class View",
+ expected_type=str,
+ instruction='Generate the mermaid class diagram corresponding to source code in "context."',
+ example="""
+ classDiagram
+ class A {
+ -int x
+ +int y
+ -int speed
+ -int direction
+ +__init__(x: int, y: int, speed: int, direction: int)
+ +change_direction(new_direction: int) None
+ +move() None
+ }
+ """,
+)
+
+REBUILD_CLASS_VIEW_NODES = [
+ CLASS_SOURCE_CODE_BLOCK,
+]
+
+REBUILD_CLASS_VIEW_NODE = ActionNode.from_children("RebuildClassView", REBUILD_CLASS_VIEW_NODES)
diff --git a/metagpt/const.py b/metagpt/const.py
index fcb3a2b3e..53f797001 100644
--- a/metagpt/const.py
+++ b/metagpt/const.py
@@ -99,6 +99,8 @@ CODE_SUMMARIES_FILE_REPO = "docs/code_summaries"
CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summaries"
RESOURCES_FILE_REPO = "resources"
SD_OUTPUT_FILE_REPO = "resources/SD_Output"
+GRAPH_REPO_FILE_REPO = "docs/graph_repo"
+CLASS_VIEW_FILE_REPO = "docs/class_views"
YAPI_URL = "http://yapi.deepwisdomai.com/"
diff --git a/metagpt/repo_parser.py b/metagpt/repo_parser.py
index 03cf7be79..ff34257a6 100644
--- a/metagpt/repo_parser.py
+++ b/metagpt/repo_parser.py
@@ -9,10 +9,13 @@ from __future__ import annotations
import ast
import json
+import re
+import subprocess
from pathlib import Path
from pprint import pformat
-from typing import List
+from typing import Dict, List, Optional, Tuple
+import aiofiles
import pandas as pd
from pydantic import BaseModel, Field
@@ -22,6 +25,29 @@ from metagpt.utils.common import any_to_str
from metagpt.utils.exceptions import handle_exception
+class RepoFileInfo(BaseModel):
+ file: str
+ classes: List = Field(default_factory=list)
+ functions: List = Field(default_factory=list)
+ globals: List = Field(default_factory=list)
+ page_info: List = Field(default_factory=list)
+
+
+class CodeBlockInfo(BaseModel):
+ lineno: int
+ end_lineno: int
+ type_name: str
+ tokens: List = Field(default_factory=list)
+ properties: Dict = Field(default_factory=dict)
+
+
+class ClassInfo(BaseModel):
+ name: str
+ package: Optional[str] = None
+ attributes: Dict[str, str] = Field(default_factory=dict)
+ methods: Dict[str, str] = Field(default_factory=dict)
+
+
class RepoParser(BaseModel):
base_directory: Path = Field(default=None)
@@ -31,32 +57,24 @@ class RepoParser(BaseModel):
"""Parse a Python file in the repository."""
return ast.parse(file_path.read_text()).body
- def extract_class_and_function_info(self, tree, file_path):
+ def extract_class_and_function_info(self, tree, file_path) -> RepoFileInfo:
"""Extract class, function, and global variable information from the AST."""
- file_info = {
- "file": str(file_path.relative_to(self.base_directory)),
- "classes": [],
- "functions": [],
- "globals": [],
- }
-
- page_info = []
+ file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory)))
for node in tree:
info = RepoParser.node_to_str(node)
- page_info.append(info)
+ file_info.page_info.append(info)
if isinstance(node, ast.ClassDef):
class_methods = [m.name for m in node.body if is_func(m)]
- file_info["classes"].append({"name": node.name, "methods": class_methods})
+ file_info.classes.append({"name": node.name, "methods": class_methods})
elif is_func(node):
- file_info["functions"].append(node.name)
+ file_info.functions.append(node.name)
elif isinstance(node, (ast.Assign, ast.AnnAssign)):
for target in node.targets if isinstance(node, ast.Assign) else [node.target]:
if isinstance(target, ast.Name):
- file_info["globals"].append(target.id)
- file_info["page_info"] = page_info
+ file_info.globals.append(target.id)
return file_info
- def generate_symbols(self):
+ def generate_symbols(self) -> List[RepoFileInfo]:
files_classes = []
directory = self.base_directory
@@ -93,37 +111,213 @@ class RepoParser(BaseModel):
self.generate_dataframe_structure(output_path)
@staticmethod
- def node_to_str(node) -> (int, int, str, str | List):
- def _parse_name(n):
- if n.asname:
- return f"{n.name} as {n.asname}"
- return n.name
-
+ def node_to_str(node) -> (int, int, str, str | Tuple):
if any_to_str(node) == any_to_str(ast.Expr):
- return node.lineno, node.end_lineno, any_to_str(node), RepoParser._parse_expr(node)
+ return CodeBlockInfo(
+ lineno=node.lineno,
+ end_lineno=node.end_lineno,
+ type_name=any_to_str(node),
+ tokens=RepoParser._parse_expr(node),
+ )
mappings = {
- any_to_str(ast.Import): lambda x: [_parse_name(n) for n in x.names],
- any_to_str(ast.Assign): lambda x: [n.id for n in x.targets],
+ any_to_str(ast.Import): lambda x: [RepoParser._parse_name(n) for n in x.names],
+ any_to_str(ast.Assign): RepoParser._parse_assign,
any_to_str(ast.ClassDef): lambda x: x.name,
any_to_str(ast.FunctionDef): lambda x: x.name,
- any_to_str(ast.ImportFrom): lambda x: {"module": x.module, "names": [_parse_name(n) for n in x.names]},
- any_to_str(ast.If): lambda x: x.test.left.id,
+ any_to_str(ast.ImportFrom): lambda x: {
+ "module": x.module,
+ "names": [RepoParser._parse_name(n) for n in x.names],
+ },
+ any_to_str(ast.If): RepoParser._parse_if,
+ any_to_str(ast.AsyncFunctionDef): lambda x: x.name,
}
func = mappings.get(any_to_str(node))
if func:
- return node.lineno, node.end_lineno, any_to_str(node), func(node)
- return node.lineno, node.end_lineno, any_to_str(node), None
+ code_block = CodeBlockInfo(lineno=node.lineno, end_lineno=node.end_lineno, type_name=any_to_str(node))
+ val = func(node)
+ if isinstance(val, dict):
+ code_block.properties = val
+ elif isinstance(val, list):
+ code_block.tokens = val
+ elif isinstance(val, str):
+ code_block.tokens = [val]
+ else:
+ raise NotImplementedError(f"Not implement:{val}")
+ return code_block
+ raise NotImplementedError(f"Not implement code block:{node.lineno}, {node.end_lineno}, {any_to_str(node)}")
@staticmethod
- def _parse_expr(node) -> (int, int, str, str | List):
- if isinstance(node.value, ast.Constant):
- return any_to_str(ast.Constant), node.value.value
- if isinstance(node.value, ast.Call):
- if isinstance(node.value.func, ast.Attribute):
- return any_to_str(ast.Call), f"{node.value.func.value.id}.{node.value.func.attr}"
- if isinstance(node.value.func, ast.Name):
- return any_to_str(ast.Call), node.value.func.id
- return any_to_str(node.value), None
+ def _parse_expr(node) -> List:
+ funcs = {
+ any_to_str(ast.Constant): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value)],
+ any_to_str(ast.Call): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value.func)],
+ }
+ func = funcs.get(any_to_str(node.value))
+ if func:
+ return func(node)
+ raise NotImplementedError(f"Not implement: {node.value}")
+
+ @staticmethod
+ def _parse_name(n):
+ if n.asname:
+ return f"{n.name} as {n.asname}"
+ return n.name
+
+ @staticmethod
+ def _parse_if(n):
+ tokens = [RepoParser._parse_variable(n.test.left)]
+ for item in n.test.comparators:
+ tokens.append(RepoParser._parse_variable(item))
+ return tokens
+
+ @staticmethod
+ def _parse_variable(node):
+ funcs = {
+ any_to_str(ast.Constant): lambda x: x.value,
+ any_to_str(ast.Name): lambda x: x.id,
+ any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}",
+ }
+ func = funcs.get(any_to_str(node))
+ if not func:
+ raise NotImplementedError(f"Not implement:{node}")
+ return func(node)
+
+ @staticmethod
+ def _parse_assign(node):
+ return [RepoParser._parse_variable(t) for t in node.targets]
+
+ async def rebuild_class_views(self, path: str | Path = None):
+ if not path:
+ path = self.base_directory
+ path = Path(path)
+ if not path.exists():
+ return
+ command = f"pyreverse {str(path)} -o dot"
+ result = subprocess.run(command, shell=True, check=True, cwd=str(path))
+ if result.returncode != 0:
+ raise ValueError(f"{result}")
+ class_view_pathname = path / "classes.dot"
+ class_views = await self._parse_classes(class_view_pathname)
+ packages_pathname = path / "packages.dot"
+ class_views = RepoParser._repair_namespaces(class_views=class_views, path=path)
+ class_view_pathname.unlink(missing_ok=True)
+ packages_pathname.unlink(missing_ok=True)
+ return class_views
+
+ async def _parse_classes(self, class_view_pathname):
+ class_views = []
+ if not class_view_pathname.exists():
+ return class_views
+ async with aiofiles.open(str(class_view_pathname), mode="r") as reader:
+ lines = await reader.readlines()
+ for line in lines:
+ package_name, info = RepoParser._split_class_line(line)
+ if not package_name:
+ continue
+            class_name, members, functions = re.split(r"(?<!\\)\|", info)
+            class_info = ClassInfo(name=class_name)
+            class_info.package = package_name
+            for m in members.split("\n"):
+                if not m:
+                    continue
+                vn, _, vt = m.partition(" : ")
+                class_info.attributes[vn.strip()] = vt.strip()
+            for f in functions.split("\n"):
+                if not f:
+                    continue
+                ix = f.find("(")
+                fn = f[0:ix] if ix >= 0 else f
+                class_info.methods[fn] = f[ix:] if ix >= 0 else ""
+            class_views.append(class_info)
+        return class_views
+
+    @staticmethod
+    def _split_class_line(line):
+        part_splitor = '" ['
+        ix = line.find(part_splitor)
+        if ix < 0:
+            return None, None
+        class_name = line[0:ix].strip('"').strip()
+        left = line[ix + len(part_splitor) :]
+        begin_flag = "label=<{"
+        end_flag = "}>"
+        if begin_flag not in left or end_flag not in left:
+            return None, None
+        bix = left.find(begin_flag)
+        eix = left.rfind(end_flag)
+        info = left[bix + len(begin_flag) : eix]
+        info = re.sub(r"<br[^>]*>", "\n", info)
+        return class_name, info
+
+ @staticmethod
+ def _create_path_mapping(path: str | Path) -> Dict[str, str]:
+ mappings = {
+ str(path).replace("/", "."): str(path),
+ }
+ files = []
+ try:
+ directory_path = Path(path)
+ if not directory_path.exists():
+ return mappings
+ for file_path in directory_path.iterdir():
+ if file_path.is_file():
+ files.append(str(file_path))
+ else:
+ subfolder_files = RepoParser._create_path_mapping(path=file_path)
+ mappings.update(subfolder_files)
+ except Exception as e:
+ logger.error(f"Error: {e}")
+ for f in files:
+ mappings[str(Path(f).with_suffix("")).replace("/", ".")] = str(f)
+
+ return mappings
+
+ @staticmethod
+ def _repair_namespaces(class_views: List[ClassInfo], path: str | Path) -> List[ClassInfo]:
+ if not class_views:
+ return []
+ c = class_views[0]
+ full_key = str(path).lstrip("/").replace("/", ".")
+ root_namespace = RepoParser._find_root(full_key, c.package)
+ root_path = root_namespace.replace(".", "/")
+
+ mappings = RepoParser._create_path_mapping(path=path)
+ new_mappings = {}
+ ix_root_namespace = len(root_namespace)
+ ix_root_path = len(root_path)
+ for k, v in mappings.items():
+ nk = k[ix_root_namespace:]
+ nv = v[ix_root_path:]
+ new_mappings[nk] = nv
+
+ for c in class_views:
+ c.package = RepoParser._repair_ns(c.package, new_mappings)
+ return class_views
+
+ @staticmethod
+ def _repair_ns(package, mappings):
+ file_ns = package
+ while file_ns != "":
+ if file_ns not in mappings:
+ ix = file_ns.rfind(".")
+ file_ns = file_ns[0:ix]
+ continue
+ break
+ internal_ns = package[ix + 1 :]
+ ns = mappings[file_ns] + ":" + internal_ns.replace(".", ":")
+ return ns
+
+ @staticmethod
+ def _find_root(full_key, package) -> str:
+ left = full_key
+ while left != "":
+ if left in package:
+ break
+ if "." not in left:
+ break
+ ix = left.find(".")
+ left = left[ix + 1 :]
+ ix = full_key.rfind(left)
+ return "." + full_key[0:ix]
def is_func(node):
diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py
index 8fa729556..a5d2100cc 100644
--- a/metagpt/utils/common.py
+++ b/metagpt/utils/common.py
@@ -17,8 +17,8 @@ import inspect
import os
import platform
import re
-import typing
-from typing import List, Tuple, Union
+from pathlib import Path
+from typing import Callable, List, Tuple, Union
import aiofiles
import loguru
@@ -332,7 +332,7 @@ def get_class_name(cls) -> str:
return f"{cls.__module__}.{cls.__name__}"
-def any_to_str(val: str | typing.Callable) -> str:
+def any_to_str(val: str | Callable) -> str:
"""Return the class name or the class name of the object, or 'val' if it's a string type."""
if isinstance(val, str):
return val
@@ -443,3 +443,20 @@ async def aread(file_path: str) -> str:
async with aiofiles.open(str(file_path), mode="r") as reader:
content = await reader.read()
return content
+
+
+async def read_file_block(filename: str | Path, lineno: int, end_lineno: int):
+ if not Path(filename).exists():
+ return ""
+ lines = []
+ async with aiofiles.open(str(filename), mode="r") as reader:
+ ix = 0
+ while ix < end_lineno:
+ ix += 1
+ line = await reader.readline()
+ if ix < lineno:
+ continue
+ if ix > end_lineno:
+ break
+ lines.append(line)
+ return "".join(lines)
diff --git a/metagpt/utils/di_graph_repository.py b/metagpt/utils/di_graph_repository.py
index 9bbd38d5f..08f4327fa 100644
--- a/metagpt/utils/di_graph_repository.py
+++ b/metagpt/utils/di_graph_repository.py
@@ -10,11 +10,12 @@ from __future__ import annotations
import json
from pathlib import Path
+from typing import List
import aiofiles
import networkx
-from metagpt.utils.graph_repository import GraphRepository
+from metagpt.utils.graph_repository import SPO, GraphRepository
class DiGraphRepository(GraphRepository):
@@ -31,6 +32,18 @@ class DiGraphRepository(GraphRepository):
async def update(self, subject: str, predicate: str, object_: str):
pass
+ async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
+ result = []
+ for s, o, p in self._repo.edges(data="predicate"):
+ if subject and subject != s:
+ continue
+ if predicate and predicate != p:
+ continue
+ if object_ and object_ != o:
+ continue
+ result.append(SPO(subject=s, predicate=p, object_=o))
+ return result
+
def json(self) -> str:
m = networkx.node_link_data(self._repo)
data = json.dumps(m)
@@ -53,10 +66,12 @@ class DiGraphRepository(GraphRepository):
@staticmethod
async def load_from(pathname: str | Path) -> GraphRepository:
- name = Path(pathname).with_suffix("").name
- root = Path(pathname).parent
+ pathname = Path(pathname)
+ name = pathname.with_suffix("").name
+ root = pathname.parent
graph = DiGraphRepository(name=name, root=root)
- await graph.load(pathname=pathname)
+ if pathname.exists():
+ await graph.load(pathname=pathname)
return graph
@property
diff --git a/metagpt/utils/graph_repository.py b/metagpt/utils/graph_repository.py
index 600575b4e..37da3dee4 100644
--- a/metagpt/utils/graph_repository.py
+++ b/metagpt/utils/graph_repository.py
@@ -6,18 +6,38 @@
@File : graph_repository.py
@Desc : Superclass for graph repository.
"""
+
from abc import ABC, abstractmethod
-from enum import Enum
+from pathlib import Path
+from typing import List
+
+from pydantic import BaseModel
+
+from metagpt.repo_parser import ClassInfo, RepoFileInfo
+from metagpt.utils.common import concat_namespace
-class GraphKeyword(Enum):
+class GraphKeyword:
IS = "is"
CLASS = "class"
FUNCTION = "function"
+ SOURCE_CODE = "source_code"
+ NULL = ""
GLOBAL_VARIABLE = "global_variable"
CLASS_FUNCTION = "class_function"
CLASS_PROPERTY = "class_property"
HAS_CLASS = "has_class"
+ HAS_PAGE_INFO = "has_page_info"
+ HAS_CLASS_VIEW = "has_class_view"
+ HAS_SEQUENCE_VIEW = "has_sequence_view"
+ HAS_ARGS_DESC = "has_args_desc"
+ HAS_TYPE_DESC = "has_type_desc"
+
+
+class SPO(BaseModel):
+ subject: str
+ predicate: str
+ object_: str
class GraphRepository(ABC):
@@ -37,6 +57,94 @@ class GraphRepository(ABC):
async def update(self, subject: str, predicate: str, object_: str):
pass
+ @abstractmethod
+ async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
+ pass
+
@property
def name(self) -> str:
return self._repo_name
+
+ @staticmethod
+ async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info: RepoFileInfo):
+ await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
+ file_types = {".py": "python", ".js": "javascript"}
+ file_type = file_types.get(Path(file_info.file).suffix, GraphKeyword.NULL)
+ await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=file_type)
+ for c in file_info.classes:
+ class_name = c.get("name", "")
+ await graph_db.insert(
+ subject=file_info.file,
+ predicate=GraphKeyword.HAS_CLASS,
+ object_=concat_namespace(file_info.file, class_name),
+ )
+ await graph_db.insert(
+ subject=concat_namespace(file_info.file, class_name),
+ predicate=GraphKeyword.IS,
+ object_=GraphKeyword.CLASS,
+ )
+ methods = c.get("methods", [])
+ for fn in methods:
+ await graph_db.insert(
+ subject=concat_namespace(file_info.file, class_name, fn),
+ predicate=GraphKeyword.IS,
+ object_=GraphKeyword.CLASS_FUNCTION,
+ )
+ for f in file_info.functions:
+ await graph_db.insert(
+ subject=concat_namespace(file_info.file, f), predicate=GraphKeyword.IS, object_=GraphKeyword.FUNCTION
+ )
+ for g in file_info.globals:
+ await graph_db.insert(
+ subject=concat_namespace(file_info.file, g),
+ predicate=GraphKeyword.IS,
+ object_=GraphKeyword.GLOBAL_VARIABLE,
+ )
+ for code_block in file_info.page_info:
+ if code_block.tokens:
+ await graph_db.insert(
+ subject=concat_namespace(file_info.file, *code_block.tokens),
+ predicate=GraphKeyword.HAS_PAGE_INFO,
+ object_=code_block.json(ensure_ascii=False),
+ )
+ for k, v in code_block.properties.items():
+ await graph_db.insert(
+ subject=concat_namespace(file_info.file, k, v),
+ predicate=GraphKeyword.HAS_PAGE_INFO,
+ object_=code_block.json(ensure_ascii=False),
+ )
+
+ @staticmethod
+ async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[ClassInfo]):
+ for c in class_views:
+ filename, class_name = c.package.split(":", 1)
+ await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
+ file_types = {".py": "python", ".js": "javascript"}
+ file_type = file_types.get(Path(filename).suffix, GraphKeyword.NULL)
+ await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=file_type)
+ await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_CLASS, object_=class_name)
+ await graph_db.insert(
+ subject=c.package,
+ predicate=GraphKeyword.IS,
+ object_=GraphKeyword.CLASS,
+ )
+ for vn, vt in c.attributes.items():
+ await graph_db.insert(
+ subject=concat_namespace(c.package, vn),
+ predicate=GraphKeyword.IS,
+ object_=GraphKeyword.CLASS_PROPERTY,
+ )
+ await graph_db.insert(
+ subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
+ )
+ for fn, desc in c.methods.items():
+ await graph_db.insert(
+ subject=concat_namespace(c.package, fn),
+ predicate=GraphKeyword.IS,
+ object_=GraphKeyword.CLASS_FUNCTION,
+ )
+ await graph_db.insert(
+ subject=concat_namespace(c.package, fn),
+ predicate=GraphKeyword.HAS_ARGS_DESC,
+ object_=desc,
+ )
diff --git a/requirements.txt b/requirements.txt
index 4310aec6c..c4e674569 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -57,4 +57,5 @@ socksio~=1.0.0
gitignore-parser==0.1.9
connexion[swagger-ui]
websockets~=12.0
-networkx~=3.2.1
\ No newline at end of file
+networkx~=3.2.1
+pylint~=3.0.3
\ No newline at end of file
diff --git a/tests/metagpt/actions/test_rebuild_class_view.py b/tests/metagpt/actions/test_rebuild_class_view.py
new file mode 100644
index 000000000..955c6ae3b
--- /dev/null
+++ b/tests/metagpt/actions/test_rebuild_class_view.py
@@ -0,0 +1,24 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2023/12/20
+@Author : mashenquan
+@File : test_rebuild_class_view.py
+@Desc : Unit tests for rebuild_class_view.py
+"""
+from pathlib import Path
+
+import pytest
+
+from metagpt.actions.rebuild_class_view import RebuildClassView
+from metagpt.llm import LLM
+
+
+@pytest.mark.asyncio
+async def test_rebuild():
+ action = RebuildClassView(name="RedBean", context=Path(__file__).parent.parent, llm=LLM())
+ await action.run()
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-s"])
diff --git a/tests/metagpt/test_repo_parser.py b/tests/metagpt/test_repo_parser.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/metagpt/utils/test_di_graph_repository.py b/tests/metagpt/utils/test_di_graph_repository.py
index ec2cb4d01..0a8011e51 100644
--- a/tests/metagpt/utils/test_di_graph_repository.py
+++ b/tests/metagpt/utils/test_di_graph_repository.py
@@ -14,9 +14,8 @@ from pydantic import BaseModel
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.repo_parser import RepoParser
-from metagpt.utils.common import concat_namespace
from metagpt.utils.di_graph_repository import DiGraphRepository
-from metagpt.utils.graph_repository import GraphKeyword
+from metagpt.utils.graph_repository import GraphRepository
@pytest.mark.asyncio
@@ -57,23 +56,7 @@ async def test_js_parser():
repo_parser = RepoParser(base_directory=data.path)
symbols = repo_parser.generate_symbols()
for s in symbols:
- ns = s.get("file", "")
- for c in s.get("classes", []):
- await graph.insert(
- subject=concat_namespace(ns, c), predicate=GraphKeyword.IS.value, object_=GraphKeyword.CLASS.value
- )
- for f in s.get("functions", []):
- await graph.insert(
- subject=concat_namespace(ns, f),
- predicate=GraphKeyword.IS.value,
- object_=GraphKeyword.FUNCTION.value,
- )
- for g in s.get("globals", []):
- await graph.insert(
- subject=concat_namespace(ns, g),
- predicate=GraphKeyword.IS.value,
- object_=GraphKeyword.GLOBAL_VARIABLE.value,
- )
+ await GraphRepository.update_graph_db(graph_db=graph, file_info=s)
data = graph.json()
assert data
@@ -85,35 +68,14 @@ async def test_codes():
graph = DiGraphRepository(name="test", root=path)
symbols = repo_parser.generate_symbols()
- for s in symbols:
- ns = s.get("file", "")
- for c in s.get("classes", []):
- class_name = c.get("name", "")
- await graph.insert(
- subject=ns, predicate=GraphKeyword.HAS_CLASS.value, object_=concat_namespace(ns, class_name)
- )
- await graph.insert(
- subject=concat_namespace(ns, class_name),
- predicate=GraphKeyword.IS.value,
- object_=GraphKeyword.CLASS.value,
- )
- methods = c.get("methods", [])
- for fn in methods:
- await graph.insert(
- subject=concat_namespace(ns, class_name, fn),
- predicate=GraphKeyword.IS.value,
- object_=GraphKeyword.CLASS_FUNCTION.value,
- )
- for f in s.get("functions", []):
- await graph.insert(
- subject=concat_namespace(ns, f), predicate=GraphKeyword.IS.value, object_=GraphKeyword.FUNCTION.value
- )
- for g in s.get("globals", []):
- await graph.insert(
- subject=concat_namespace(ns, g),
- predicate=GraphKeyword.IS.value,
- object_=GraphKeyword.GLOBAL_VARIABLE.value,
- )
+ for file_info in symbols:
+ for code_block in file_info.page_info:
+ try:
+ val = code_block.json(ensure_ascii=False)
+ assert val
+ except TypeError as e:
+ assert not e
+ await GraphRepository.update_graph_db(graph_db=graph, file_info=file_info)
data = graph.json()
assert data
print(data)