feat: +pylint class view

This commit is contained in:
莘权 马 2023-12-21 12:09:39 +08:00
parent 81b1e5bb1c
commit 863a30e903
12 changed files with 528 additions and 98 deletions

View file

@ -39,7 +39,7 @@ SIMPLE_TEMPLATE = """
{constraint}
## action
Fill in the above nodes based on the format example.
Based on the 'context' content, fill in the {node_name} using the 'format example' format above."
"""
@ -247,8 +247,13 @@ class ActionNode:
# FIXME: json instruction会带来格式问题"Project name": "web_2048 # 项目名称使用下划线",
self.instruction = self.compile_instruction(to="markdown", mode=mode)
self.example = self.compile_example(to=to, tag="CONTENT", mode=mode)
node_name = "nodes" if template != SIMPLE_TEMPLATE else f'"{list(self.children.keys())[0]}" node'
prompt = template.format(
context=context, example=self.example, instruction=self.instruction, constraint=CONSTRAINT
context=context,
example=self.example,
instruction=self.instruction,
constraint=CONSTRAINT,
node_name=node_name,
)
return prompt
@ -302,6 +307,7 @@ class ActionNode:
mapping = self.get_mapping(mode)
class_name = f"{self.key}_AN"
print(prompt)
output = await self._aask_v1(prompt, class_name, mapping, format=to)
self.content = output.content
self.instruct_content = output.instruct_content

View file

@ -0,0 +1,68 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19
@Author : mashenquan
@File : rebuild_class_view.py
@Desc : Rebuild class view info
"""
import re
from pathlib import Path
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.const import CLASS_VIEW_FILE_REPO, GRAPH_REPO_FILE_REPO
from metagpt.repo_parser import RepoParser
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class RebuildClassView(Action):
def __init__(self, name="", context=None, llm=None):
super().__init__(name=name, context=context, llm=llm)
async def run(self, with_messages=None, format=CONFIG.prompt_format):
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
repo_parser = RepoParser(base_directory=self.context)
class_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint
await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
symbols = repo_parser.generate_symbols() # use ast
for file_info in symbols:
await GraphRepository.update_graph_db_with_file_info(graph_db, file_info)
await self._create_mermaid_class_view(graph_db=graph_db)
await self._save(graph_db=graph_db)
async def _create_mermaid_class_view(self, graph_db):
pass
# dataset = await graph_db.select(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_PAGE_INFO)
# if not dataset:
# logger.warning(f"No page info for {concat_namespace(filename, class_name)}")
# return
# code_block_info = CodeBlockInfo.parse_raw(dataset[0].object_)
# src_code = await read_file_block(filename=Path(self.context) / filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno)
# code_type = ""
# dataset = await graph_db.select(subject=filename, predicate=GraphKeyword.IS)
# for spo in dataset:
# if spo.object_ in ["javascript", "python"]:
# code_type = spo.object_
# break
# try:
# node = await REBUILD_CLASS_VIEW_NODE.fill(context=f"```{code_type}\n{src_code}\n```", llm=self.llm, to=format)
# class_view = node.instruct_content.dict()["Class View"]
# except Exception as e:
# class_view = RepoParser.rebuild_class_view(src_code, code_type)
# await graph_db.insert(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_CLASS_VIEW, object_=class_view)
# logger.info(f"{concat_namespace(filename, class_name)} {GraphKeyword.HAS_CLASS_VIEW} {class_view}")
async def _save(self, graph_db):
class_view_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO)
dataset = await graph_db.select(predicate=GraphKeyword.HAS_CLASS_VIEW)
all_class_view = []
for spo in dataset:
title = f"---\ntitle: {spo.subject}\n---\n"
filename = re.sub(r"[/:]", "_", spo.subject) + ".mmd"
await class_view_file_repo.save(filename=filename, content=title + spo.object_)
all_class_view.append(spo.object_)
await class_view_file_repo.save(filename="all.mmd", content="\n".join(all_class_view))

View file

@ -0,0 +1,33 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19
@Author : mashenquan
@File : rebuild_class_view_an.py
@Desc : Defines `ActionNode` objects used by rebuild_class_view.py
"""
from metagpt.actions.action_node import ActionNode
CLASS_SOURCE_CODE_BLOCK = ActionNode(
key="Class View",
expected_type=str,
instruction='Generate the mermaid class diagram corresponding to source code in "context."',
example="""
classDiagram
class A {
-int x
+int y
-int speed
-int direction
+__init__(x: int, y: int, speed: int, direction: int)
+change_direction(new_direction: int) None
+move() None
}
""",
)
REBUILD_CLASS_VIEW_NODES = [
CLASS_SOURCE_CODE_BLOCK,
]
REBUILD_CLASS_VIEW_NODE = ActionNode.from_children("RebuildClassView", REBUILD_CLASS_VIEW_NODES)

View file

@ -99,6 +99,8 @@ CODE_SUMMARIES_FILE_REPO = "docs/code_summaries"
CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summaries"
RESOURCES_FILE_REPO = "resources"
SD_OUTPUT_FILE_REPO = "resources/SD_Output"
GRAPH_REPO_FILE_REPO = "docs/graph_repo"
CLASS_VIEW_FILE_REPO = "docs/class_views"
YAPI_URL = "http://yapi.deepwisdomai.com/"

View file

@ -9,10 +9,13 @@ from __future__ import annotations
import ast
import json
import re
import subprocess
from pathlib import Path
from pprint import pformat
from typing import List
from typing import Dict, List, Optional, Tuple
import aiofiles
import pandas as pd
from pydantic import BaseModel, Field
@ -22,6 +25,29 @@ from metagpt.utils.common import any_to_str
from metagpt.utils.exceptions import handle_exception
class RepoFileInfo(BaseModel):
file: str
classes: List = Field(default_factory=list)
functions: List = Field(default_factory=list)
globals: List = Field(default_factory=list)
page_info: List = Field(default_factory=list)
class CodeBlockInfo(BaseModel):
lineno: int
end_lineno: int
type_name: str
tokens: List = Field(default_factory=list)
properties: Dict = Field(default_factory=dict)
class ClassInfo(BaseModel):
name: str
package: Optional[str] = None
attributes: Dict[str, str] = Field(default_factory=dict)
methods: Dict[str, str] = Field(default_factory=dict)
class RepoParser(BaseModel):
base_directory: Path = Field(default=None)
@ -31,32 +57,24 @@ class RepoParser(BaseModel):
"""Parse a Python file in the repository."""
return ast.parse(file_path.read_text()).body
def extract_class_and_function_info(self, tree, file_path):
def extract_class_and_function_info(self, tree, file_path) -> RepoFileInfo:
"""Extract class, function, and global variable information from the AST."""
file_info = {
"file": str(file_path.relative_to(self.base_directory)),
"classes": [],
"functions": [],
"globals": [],
}
page_info = []
file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory)))
for node in tree:
info = RepoParser.node_to_str(node)
page_info.append(info)
file_info.page_info.append(info)
if isinstance(node, ast.ClassDef):
class_methods = [m.name for m in node.body if is_func(m)]
file_info["classes"].append({"name": node.name, "methods": class_methods})
file_info.classes.append({"name": node.name, "methods": class_methods})
elif is_func(node):
file_info["functions"].append(node.name)
file_info.functions.append(node.name)
elif isinstance(node, (ast.Assign, ast.AnnAssign)):
for target in node.targets if isinstance(node, ast.Assign) else [node.target]:
if isinstance(target, ast.Name):
file_info["globals"].append(target.id)
file_info["page_info"] = page_info
file_info.globals.append(target.id)
return file_info
def generate_symbols(self):
def generate_symbols(self) -> List[RepoFileInfo]:
files_classes = []
directory = self.base_directory
@ -93,37 +111,213 @@ class RepoParser(BaseModel):
self.generate_dataframe_structure(output_path)
@staticmethod
def node_to_str(node) -> (int, int, str, str | List):
def _parse_name(n):
if n.asname:
return f"{n.name} as {n.asname}"
return n.name
def node_to_str(node) -> (int, int, str, str | Tuple):
if any_to_str(node) == any_to_str(ast.Expr):
return node.lineno, node.end_lineno, any_to_str(node), RepoParser._parse_expr(node)
return CodeBlockInfo(
lineno=node.lineno,
end_lineno=node.end_lineno,
type_name=any_to_str(node),
tokens=RepoParser._parse_expr(node),
)
mappings = {
any_to_str(ast.Import): lambda x: [_parse_name(n) for n in x.names],
any_to_str(ast.Assign): lambda x: [n.id for n in x.targets],
any_to_str(ast.Import): lambda x: [RepoParser._parse_name(n) for n in x.names],
any_to_str(ast.Assign): RepoParser._parse_assign,
any_to_str(ast.ClassDef): lambda x: x.name,
any_to_str(ast.FunctionDef): lambda x: x.name,
any_to_str(ast.ImportFrom): lambda x: {"module": x.module, "names": [_parse_name(n) for n in x.names]},
any_to_str(ast.If): lambda x: x.test.left.id,
any_to_str(ast.ImportFrom): lambda x: {
"module": x.module,
"names": [RepoParser._parse_name(n) for n in x.names],
},
any_to_str(ast.If): RepoParser._parse_if,
any_to_str(ast.AsyncFunctionDef): lambda x: x.name,
}
func = mappings.get(any_to_str(node))
if func:
return node.lineno, node.end_lineno, any_to_str(node), func(node)
return node.lineno, node.end_lineno, any_to_str(node), None
code_block = CodeBlockInfo(lineno=node.lineno, end_lineno=node.end_lineno, type_name=any_to_str(node))
val = func(node)
if isinstance(val, dict):
code_block.properties = val
elif isinstance(val, list):
code_block.tokens = val
elif isinstance(val, str):
code_block.tokens = [val]
else:
raise NotImplementedError(f"Not implement:{val}")
return code_block
raise NotImplementedError(f"Not implement code block:{node.lineno}, {node.end_lineno}, {any_to_str(node)}")
@staticmethod
def _parse_expr(node) -> (int, int, str, str | List):
if isinstance(node.value, ast.Constant):
return any_to_str(ast.Constant), node.value.value
if isinstance(node.value, ast.Call):
if isinstance(node.value.func, ast.Attribute):
return any_to_str(ast.Call), f"{node.value.func.value.id}.{node.value.func.attr}"
if isinstance(node.value.func, ast.Name):
return any_to_str(ast.Call), node.value.func.id
return any_to_str(node.value), None
def _parse_expr(node) -> List:
funcs = {
any_to_str(ast.Constant): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value)],
any_to_str(ast.Call): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value.func)],
}
func = funcs.get(any_to_str(node.value))
if func:
return func(node)
raise NotImplementedError(f"Not implement: {node.value}")
@staticmethod
def _parse_name(n):
if n.asname:
return f"{n.name} as {n.asname}"
return n.name
@staticmethod
def _parse_if(n):
tokens = [RepoParser._parse_variable(n.test.left)]
for item in n.test.comparators:
tokens.append(RepoParser._parse_variable(item))
return tokens
@staticmethod
def _parse_variable(node):
funcs = {
any_to_str(ast.Constant): lambda x: x.value,
any_to_str(ast.Name): lambda x: x.id,
any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}",
}
func = funcs.get(any_to_str(node))
if not func:
raise NotImplementedError(f"Not implement:{node}")
return func(node)
@staticmethod
def _parse_assign(node):
return [RepoParser._parse_variable(t) for t in node.targets]
async def rebuild_class_views(self, path: str | Path = None):
if not path:
path = self.base_directory
path = Path(path)
if not path.exists():
return
command = f"pyreverse {str(path)} -o dot"
result = subprocess.run(command, shell=True, check=True, cwd=str(path))
if result.returncode != 0:
raise ValueError(f"{result}")
class_view_pathname = path / "classes.dot"
class_views = await self._parse_classes(class_view_pathname)
packages_pathname = path / "packages.dot"
class_views = RepoParser._repair_namespaces(class_views=class_views, path=path)
class_view_pathname.unlink(missing_ok=True)
packages_pathname.unlink(missing_ok=True)
return class_views
async def _parse_classes(self, class_view_pathname):
class_views = []
if not class_view_pathname.exists():
return class_views
async with aiofiles.open(str(class_view_pathname), mode="r") as reader:
lines = await reader.readlines()
for line in lines:
package_name, info = RepoParser._split_class_line(line)
if not package_name:
continue
class_name, members, functions = re.split(r"(?<!\\)\|", info)
class_info = ClassInfo(name=class_name)
class_info.package = package_name
for m in members.split("\n"):
if not m:
continue
member_name = m.split(":", 1)[0].strip() if ":" in m else m.strip()
class_info.attributes[member_name] = m
for f in functions.split("\n"):
if not f:
continue
function_name, _ = f.split("(", 1)
class_info.methods[function_name] = f
class_views.append(class_info)
return class_views
@staticmethod
def _split_class_line(line):
part_splitor = '" ['
if part_splitor not in line:
return None, None
ix = line.find(part_splitor)
class_name = line[0:ix].replace('"', "")
left = line[ix:]
begin_flag = "label=<{"
end_flag = "}>"
if begin_flag not in left or end_flag not in left:
return None, None
bix = left.find(begin_flag)
eix = left.rfind(end_flag)
info = left[bix + len(begin_flag) : eix]
info = re.sub(r"<br[^>]*>", "\n", info)
return class_name, info
@staticmethod
def _create_path_mapping(path: str | Path) -> Dict[str, str]:
mappings = {
str(path).replace("/", "."): str(path),
}
files = []
try:
directory_path = Path(path)
if not directory_path.exists():
return mappings
for file_path in directory_path.iterdir():
if file_path.is_file():
files.append(str(file_path))
else:
subfolder_files = RepoParser._create_path_mapping(path=file_path)
mappings.update(subfolder_files)
except Exception as e:
logger.error(f"Error: {e}")
for f in files:
mappings[str(Path(f).with_suffix("")).replace("/", ".")] = str(f)
return mappings
@staticmethod
def _repair_namespaces(class_views: List[ClassInfo], path: str | Path) -> List[ClassInfo]:
if not class_views:
return []
c = class_views[0]
full_key = str(path).lstrip("/").replace("/", ".")
root_namespace = RepoParser._find_root(full_key, c.package)
root_path = root_namespace.replace(".", "/")
mappings = RepoParser._create_path_mapping(path=path)
new_mappings = {}
ix_root_namespace = len(root_namespace)
ix_root_path = len(root_path)
for k, v in mappings.items():
nk = k[ix_root_namespace:]
nv = v[ix_root_path:]
new_mappings[nk] = nv
for c in class_views:
c.package = RepoParser._repair_ns(c.package, new_mappings)
return class_views
@staticmethod
def _repair_ns(package, mappings):
file_ns = package
while file_ns != "":
if file_ns not in mappings:
ix = file_ns.rfind(".")
file_ns = file_ns[0:ix]
continue
break
internal_ns = package[ix + 1 :]
ns = mappings[file_ns] + ":" + internal_ns.replace(".", ":")
return ns
@staticmethod
def _find_root(full_key, package) -> str:
left = full_key
while left != "":
if left in package:
break
if "." not in left:
break
ix = left.find(".")
left = left[ix + 1 :]
ix = full_key.rfind(left)
return "." + full_key[0:ix]
def is_func(node):

View file

@ -17,8 +17,8 @@ import inspect
import os
import platform
import re
import typing
from typing import List, Tuple, Union
from pathlib import Path
from typing import Callable, List, Tuple, Union
import aiofiles
import loguru
@ -332,7 +332,7 @@ def get_class_name(cls) -> str:
return f"{cls.__module__}.{cls.__name__}"
def any_to_str(val: str | typing.Callable) -> str:
def any_to_str(val: str | Callable) -> str:
"""Return the class name or the class name of the object, or 'val' if it's a string type."""
if isinstance(val, str):
return val
@ -443,3 +443,20 @@ async def aread(file_path: str) -> str:
async with aiofiles.open(str(file_path), mode="r") as reader:
content = await reader.read()
return content
async def read_file_block(filename: str | Path, lineno: int, end_lineno: int):
if not Path(filename).exists():
return ""
lines = []
async with aiofiles.open(str(filename), mode="r") as reader:
ix = 0
while ix < end_lineno:
ix += 1
line = await reader.readline()
if ix < lineno:
continue
if ix > end_lineno:
break
lines.append(line)
return "".join(lines)

View file

@ -10,11 +10,12 @@ from __future__ import annotations
import json
from pathlib import Path
from typing import List
import aiofiles
import networkx
from metagpt.utils.graph_repository import GraphRepository
from metagpt.utils.graph_repository import SPO, GraphRepository
class DiGraphRepository(GraphRepository):
@ -31,6 +32,18 @@ class DiGraphRepository(GraphRepository):
async def update(self, subject: str, predicate: str, object_: str):
pass
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
result = []
for s, o, p in self._repo.edges(data="predicate"):
if subject and subject != s:
continue
if predicate and predicate != p:
continue
if object_ and object_ != o:
continue
result.append(SPO(subject=s, predicate=p, object_=o))
return result
def json(self) -> str:
m = networkx.node_link_data(self._repo)
data = json.dumps(m)
@ -53,10 +66,12 @@ class DiGraphRepository(GraphRepository):
@staticmethod
async def load_from(pathname: str | Path) -> GraphRepository:
name = Path(pathname).with_suffix("").name
root = Path(pathname).parent
pathname = Path(pathname)
name = pathname.with_suffix("").name
root = pathname.parent
graph = DiGraphRepository(name=name, root=root)
await graph.load(pathname=pathname)
if pathname.exists():
await graph.load(pathname=pathname)
return graph
@property

View file

@ -6,18 +6,38 @@
@File : graph_repository.py
@Desc : Superclass for graph repository.
"""
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import List
from pydantic import BaseModel
from metagpt.repo_parser import ClassInfo, RepoFileInfo
from metagpt.utils.common import concat_namespace
class GraphKeyword(Enum):
class GraphKeyword:
IS = "is"
CLASS = "class"
FUNCTION = "function"
SOURCE_CODE = "source_code"
NULL = "<null>"
GLOBAL_VARIABLE = "global_variable"
CLASS_FUNCTION = "class_function"
CLASS_PROPERTY = "class_property"
HAS_CLASS = "has_class"
HAS_PAGE_INFO = "has_page_info"
HAS_CLASS_VIEW = "has_class_view"
HAS_SEQUENCE_VIEW = "has_sequence_view"
HAS_ARGS_DESC = "has_args_desc"
HAS_TYPE_DESC = "has_type_desc"
class SPO(BaseModel):
subject: str
predicate: str
object_: str
class GraphRepository(ABC):
@ -37,6 +57,94 @@ class GraphRepository(ABC):
async def update(self, subject: str, predicate: str, object_: str):
pass
@abstractmethod
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
pass
@property
def name(self) -> str:
return self._repo_name
@staticmethod
async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info: RepoFileInfo):
await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
file_types = {".py": "python", ".js": "javascript"}
file_type = file_types.get(Path(file_info.file).suffix, GraphKeyword.NULL)
await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=file_type)
for c in file_info.classes:
class_name = c.get("name", "")
await graph_db.insert(
subject=file_info.file,
predicate=GraphKeyword.HAS_CLASS,
object_=concat_namespace(file_info.file, class_name),
)
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS,
)
methods = c.get("methods", [])
for fn in methods:
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
)
for f in file_info.functions:
await graph_db.insert(
subject=concat_namespace(file_info.file, f), predicate=GraphKeyword.IS, object_=GraphKeyword.FUNCTION
)
for g in file_info.globals:
await graph_db.insert(
subject=concat_namespace(file_info.file, g),
predicate=GraphKeyword.IS,
object_=GraphKeyword.GLOBAL_VARIABLE,
)
for code_block in file_info.page_info:
if code_block.tokens:
await graph_db.insert(
subject=concat_namespace(file_info.file, *code_block.tokens),
predicate=GraphKeyword.HAS_PAGE_INFO,
object_=code_block.json(ensure_ascii=False),
)
for k, v in code_block.properties.items():
await graph_db.insert(
subject=concat_namespace(file_info.file, k, v),
predicate=GraphKeyword.HAS_PAGE_INFO,
object_=code_block.json(ensure_ascii=False),
)
@staticmethod
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[ClassInfo]):
for c in class_views:
filename, class_name = c.package.split(":", 1)
await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
file_types = {".py": "python", ".js": "javascript"}
file_type = file_types.get(Path(filename).suffix, GraphKeyword.NULL)
await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=file_type)
await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_CLASS, object_=class_name)
await graph_db.insert(
subject=c.package,
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS,
)
for vn, vt in c.attributes.items():
await graph_db.insert(
subject=concat_namespace(c.package, vn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_PROPERTY,
)
await graph_db.insert(
subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
)
for fn, desc in c.methods.items():
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
)
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.HAS_ARGS_DESC,
object_=desc,
)

View file

@ -57,4 +57,5 @@ socksio~=1.0.0
gitignore-parser==0.1.9
connexion[swagger-ui]
websockets~=12.0
networkx~=3.2.1
networkx~=3.2.1
pylint~=3.0.3

View file

@ -0,0 +1,24 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/20
@Author : mashenquan
@File : test_rebuild_class_view.py
@Desc : Unit tests for rebuild_class_view.py
"""
from pathlib import Path
import pytest
from metagpt.actions.rebuild_class_view import RebuildClassView
from metagpt.llm import LLM
@pytest.mark.asyncio
async def test_rebuild():
action = RebuildClassView(name="RedBean", context=Path(__file__).parent.parent, llm=LLM())
await action.run()
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

View file

@ -14,9 +14,8 @@ from pydantic import BaseModel
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.repo_parser import RepoParser
from metagpt.utils.common import concat_namespace
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword
from metagpt.utils.graph_repository import GraphRepository
@pytest.mark.asyncio
@ -57,23 +56,7 @@ async def test_js_parser():
repo_parser = RepoParser(base_directory=data.path)
symbols = repo_parser.generate_symbols()
for s in symbols:
ns = s.get("file", "")
for c in s.get("classes", []):
await graph.insert(
subject=concat_namespace(ns, c), predicate=GraphKeyword.IS.value, object_=GraphKeyword.CLASS.value
)
for f in s.get("functions", []):
await graph.insert(
subject=concat_namespace(ns, f),
predicate=GraphKeyword.IS.value,
object_=GraphKeyword.FUNCTION.value,
)
for g in s.get("globals", []):
await graph.insert(
subject=concat_namespace(ns, g),
predicate=GraphKeyword.IS.value,
object_=GraphKeyword.GLOBAL_VARIABLE.value,
)
await GraphRepository.update_graph_db(graph_db=graph, file_info=s)
data = graph.json()
assert data
@ -85,35 +68,14 @@ async def test_codes():
graph = DiGraphRepository(name="test", root=path)
symbols = repo_parser.generate_symbols()
for s in symbols:
ns = s.get("file", "")
for c in s.get("classes", []):
class_name = c.get("name", "")
await graph.insert(
subject=ns, predicate=GraphKeyword.HAS_CLASS.value, object_=concat_namespace(ns, class_name)
)
await graph.insert(
subject=concat_namespace(ns, class_name),
predicate=GraphKeyword.IS.value,
object_=GraphKeyword.CLASS.value,
)
methods = c.get("methods", [])
for fn in methods:
await graph.insert(
subject=concat_namespace(ns, class_name, fn),
predicate=GraphKeyword.IS.value,
object_=GraphKeyword.CLASS_FUNCTION.value,
)
for f in s.get("functions", []):
await graph.insert(
subject=concat_namespace(ns, f), predicate=GraphKeyword.IS.value, object_=GraphKeyword.FUNCTION.value
)
for g in s.get("globals", []):
await graph.insert(
subject=concat_namespace(ns, g),
predicate=GraphKeyword.IS.value,
object_=GraphKeyword.GLOBAL_VARIABLE.value,
)
for file_info in symbols:
for code_block in file_info.page_info:
try:
val = code_block.json(ensure_ascii=False)
assert val
except TypeError as e:
assert not e
await GraphRepository.update_graph_db(graph_db=graph, file_info=file_info)
data = graph.json()
assert data
print(data)