diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 6f1215920..4ed8bf22e 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -39,7 +39,7 @@ SIMPLE_TEMPLATE = """ {constraint} ## action -Fill in the above nodes based on the format example. +Based on the 'context' content, fill in the {node_name} using the 'format example' format above." """ @@ -247,8 +247,13 @@ class ActionNode: # FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线", self.instruction = self.compile_instruction(to="markdown", mode=mode) self.example = self.compile_example(to=to, tag="CONTENT", mode=mode) + node_name = "nodes" if template != SIMPLE_TEMPLATE else f'"{list(self.children.keys())[0]}" node' prompt = template.format( - context=context, example=self.example, instruction=self.instruction, constraint=CONSTRAINT + context=context, + example=self.example, + instruction=self.instruction, + constraint=CONSTRAINT, + node_name=node_name, ) return prompt @@ -302,6 +307,7 @@ class ActionNode: mapping = self.get_mapping(mode) class_name = f"{self.key}_AN" + print(prompt) output = await self._aask_v1(prompt, class_name, mapping, format=to) self.content = output.content self.instruct_content = output.instruct_content diff --git a/metagpt/actions/rebuild_class_view.py b/metagpt/actions/rebuild_class_view.py new file mode 100644 index 000000000..6da3e2989 --- /dev/null +++ b/metagpt/actions/rebuild_class_view.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/19 +@Author : mashenquan +@File : rebuild_class_view.py +@Desc : Rebuild class view info +""" +import re +from pathlib import Path + +from metagpt.actions import Action +from metagpt.config import CONFIG +from metagpt.const import CLASS_VIEW_FILE_REPO, GRAPH_REPO_FILE_REPO +from metagpt.repo_parser import RepoParser +from metagpt.utils.di_graph_repository import DiGraphRepository +from metagpt.utils.graph_repository import GraphKeyword, GraphRepository + + +class RebuildClassView(Action): + def __init__(self, name="", context=None, llm=None): + super().__init__(name=name, context=context, llm=llm) + + async def run(self, with_messages=None, format=CONFIG.prompt_format): + graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name + graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json"))) + repo_parser = RepoParser(base_directory=self.context) + class_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint + await GraphRepository.update_graph_db_with_class_views(graph_db, class_views) + symbols = repo_parser.generate_symbols() # use ast + for file_info in symbols: + await GraphRepository.update_graph_db_with_file_info(graph_db, file_info) + await self._create_mermaid_class_view(graph_db=graph_db) + await self._save(graph_db=graph_db) + + async def _create_mermaid_class_view(self, graph_db): + pass + # dataset = await graph_db.select(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_PAGE_INFO) + # if not dataset: + # logger.warning(f"No page info for {concat_namespace(filename, class_name)}") + # return + # code_block_info = CodeBlockInfo.parse_raw(dataset[0].object_) + # src_code = await read_file_block(filename=Path(self.context) / filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno) + # code_type = "" + # dataset = await graph_db.select(subject=filename, predicate=GraphKeyword.IS) + # for spo in dataset: + # if spo.object_ in ["javascript", "python"]: + # code_type = spo.object_ + # break + + # try: + # node = await REBUILD_CLASS_VIEW_NODE.fill(context=f"```{code_type}\n{src_code}\n```", llm=self.llm, to=format) + # class_view = node.instruct_content.dict()["Class View"] + # except Exception as e: + # class_view = RepoParser.rebuild_class_view(src_code, code_type) + # await graph_db.insert(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_CLASS_VIEW, object_=class_view) + # logger.info(f"{concat_namespace(filename, class_name)} {GraphKeyword.HAS_CLASS_VIEW} {class_view}") + + async def _save(self, graph_db): + class_view_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO) + dataset = await graph_db.select(predicate=GraphKeyword.HAS_CLASS_VIEW) + all_class_view = [] + for spo in dataset: + title = f"---\ntitle: {spo.subject}\n---\n" + filename = re.sub(r"[/:]", "_", spo.subject) + ".mmd" + await class_view_file_repo.save(filename=filename, content=title + spo.object_) + all_class_view.append(spo.object_) + await class_view_file_repo.save(filename="all.mmd", content="\n".join(all_class_view)) diff --git a/metagpt/actions/rebuild_class_view_an.py b/metagpt/actions/rebuild_class_view_an.py new file mode 100644 index 000000000..da32a9b5e --- /dev/null +++ b/metagpt/actions/rebuild_class_view_an.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/19 +@Author : mashenquan +@File : rebuild_class_view_an.py +@Desc : Defines `ActionNode` objects used by rebuild_class_view.py +""" +from metagpt.actions.action_node import ActionNode + +CLASS_SOURCE_CODE_BLOCK = ActionNode( + key="Class View", + expected_type=str, + instruction='Generate the mermaid class diagram corresponding to source code in "context."', + example=""" + classDiagram + class A { + -int x + +int y + -int speed + -int direction + +__init__(x: int, y: int, speed: int, direction: int) + +change_direction(new_direction: int) None + +move() None + } + """, +) + +REBUILD_CLASS_VIEW_NODES = [ + CLASS_SOURCE_CODE_BLOCK, +] + +REBUILD_CLASS_VIEW_NODE = ActionNode.from_children("RebuildClassView", REBUILD_CLASS_VIEW_NODES) diff --git a/metagpt/const.py b/metagpt/const.py index fcb3a2b3e..53f797001 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -99,6 +99,8 @@ CODE_SUMMARIES_FILE_REPO = "docs/code_summaries" CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summaries" RESOURCES_FILE_REPO = "resources" SD_OUTPUT_FILE_REPO = "resources/SD_Output" +GRAPH_REPO_FILE_REPO = "docs/graph_repo" +CLASS_VIEW_FILE_REPO = "docs/class_views" YAPI_URL = "http://yapi.deepwisdomai.com/" diff --git a/metagpt/repo_parser.py b/metagpt/repo_parser.py index 03cf7be79..ff34257a6 100644 --- a/metagpt/repo_parser.py +++ b/metagpt/repo_parser.py @@ -9,10 +9,13 @@ from __future__ import annotations import ast import json +import re +import subprocess from pathlib import Path from pprint import pformat -from typing import List +from typing import Dict, List, Optional, Tuple +import aiofiles import pandas as pd from pydantic import BaseModel, Field @@ -22,6 +25,29 @@ from metagpt.utils.common import any_to_str from metagpt.utils.exceptions import handle_exception +class RepoFileInfo(BaseModel): + file: str + classes: List = Field(default_factory=list) + functions: List = Field(default_factory=list) + globals: List = Field(default_factory=list) + page_info: List = Field(default_factory=list) + + +class CodeBlockInfo(BaseModel): + lineno: int + end_lineno: int + type_name: str + tokens: List = Field(default_factory=list) + properties: Dict = Field(default_factory=dict) + + +class ClassInfo(BaseModel): + name: str + package: Optional[str] = None + attributes: Dict[str, str] = Field(default_factory=dict) + methods: Dict[str, str] = Field(default_factory=dict) + + class RepoParser(BaseModel): base_directory: Path = Field(default=None) @@ -31,32 +57,24 @@ class RepoParser(BaseModel): """Parse a Python file in the repository.""" return ast.parse(file_path.read_text()).body - def extract_class_and_function_info(self, tree, file_path): + def extract_class_and_function_info(self, tree, file_path) -> RepoFileInfo: """Extract class, function, and global variable information from the AST.""" - file_info = { - "file": str(file_path.relative_to(self.base_directory)), - "classes": [], - "functions": [], - "globals": [], - } - - page_info = [] + file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory))) for node in tree: info = RepoParser.node_to_str(node) - page_info.append(info) + file_info.page_info.append(info) if isinstance(node, ast.ClassDef): class_methods = [m.name for m in node.body if is_func(m)] - file_info["classes"].append({"name": node.name, "methods": class_methods}) + file_info.classes.append({"name": node.name, "methods": class_methods}) elif is_func(node): - file_info["functions"].append(node.name) + file_info.functions.append(node.name) elif isinstance(node, (ast.Assign, ast.AnnAssign)): for target in node.targets if isinstance(node, ast.Assign) else [node.target]: if isinstance(target, ast.Name): - file_info["globals"].append(target.id) - file_info["page_info"] = page_info + file_info.globals.append(target.id) return file_info - def generate_symbols(self): + def generate_symbols(self) -> List[RepoFileInfo]: files_classes = [] directory = self.base_directory @@ -93,37 +111,213 @@ class RepoParser(BaseModel): self.generate_dataframe_structure(output_path) @staticmethod - def node_to_str(node) -> (int, int, str, str | List): - def _parse_name(n): - if n.asname: - return f"{n.name} as {n.asname}" - return n.name - + def node_to_str(node) -> (int, int, str, str | Tuple): if any_to_str(node) == any_to_str(ast.Expr): - return node.lineno, node.end_lineno, any_to_str(node), RepoParser._parse_expr(node) + return CodeBlockInfo( + lineno=node.lineno, + end_lineno=node.end_lineno, + type_name=any_to_str(node), + tokens=RepoParser._parse_expr(node), + ) mappings = { - any_to_str(ast.Import): lambda x: [_parse_name(n) for n in x.names], - any_to_str(ast.Assign): lambda x: [n.id for n in x.targets], + any_to_str(ast.Import): lambda x: [RepoParser._parse_name(n) for n in x.names], + any_to_str(ast.Assign): RepoParser._parse_assign, any_to_str(ast.ClassDef): lambda x: x.name, any_to_str(ast.FunctionDef): lambda x: x.name, - any_to_str(ast.ImportFrom): lambda x: {"module": x.module, "names": [_parse_name(n) for n in x.names]}, - any_to_str(ast.If): lambda x: x.test.left.id, + any_to_str(ast.ImportFrom): lambda x: { + "module": x.module, + "names": [RepoParser._parse_name(n) for n in x.names], + }, + any_to_str(ast.If): RepoParser._parse_if, + any_to_str(ast.AsyncFunctionDef): lambda x: x.name, } func = mappings.get(any_to_str(node)) if func: - return node.lineno, node.end_lineno, any_to_str(node), func(node) - return node.lineno, node.end_lineno, any_to_str(node), None + code_block = CodeBlockInfo(lineno=node.lineno, end_lineno=node.end_lineno, type_name=any_to_str(node)) + val = func(node) + if isinstance(val, dict): + code_block.properties = val + elif isinstance(val, list): + code_block.tokens = val + elif isinstance(val, str): + code_block.tokens = [val] + else: + raise NotImplementedError(f"Not implement:{val}") + return code_block + raise NotImplementedError(f"Not implement code block:{node.lineno}, {node.end_lineno}, {any_to_str(node)}") @staticmethod - def _parse_expr(node) -> (int, int, str, str | List): - if isinstance(node.value, ast.Constant): - return any_to_str(ast.Constant), node.value.value - if isinstance(node.value, ast.Call): - if isinstance(node.value.func, ast.Attribute): - return any_to_str(ast.Call), f"{node.value.func.value.id}.{node.value.func.attr}" - if isinstance(node.value.func, ast.Name): - return any_to_str(ast.Call), node.value.func.id - return any_to_str(node.value), None + def _parse_expr(node) -> List: + funcs = { + any_to_str(ast.Constant): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value)], + any_to_str(ast.Call): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value.func)], + } + func = funcs.get(any_to_str(node.value)) + if func: + return func(node) + raise NotImplementedError(f"Not implement: {node.value}") + + @staticmethod + def _parse_name(n): + if n.asname: + return f"{n.name} as {n.asname}" + return n.name + + @staticmethod + def _parse_if(n): + tokens = [RepoParser._parse_variable(n.test.left)] + for item in n.test.comparators: + tokens.append(RepoParser._parse_variable(item)) + return tokens + + @staticmethod + def _parse_variable(node): + funcs = { + any_to_str(ast.Constant): lambda x: x.value, + any_to_str(ast.Name): lambda x: x.id, + any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}", + } + func = funcs.get(any_to_str(node)) + if not func: + raise NotImplementedError(f"Not implement:{node}") + return func(node) + + @staticmethod + def _parse_assign(node): + return [RepoParser._parse_variable(t) for t in node.targets] + + async def rebuild_class_views(self, path: str | Path = None): + if not path: + path = self.base_directory + path = Path(path) + if not path.exists(): + return + command = f"pyreverse {str(path)} -o dot" + result = subprocess.run(command, shell=True, check=True, cwd=str(path)) + if result.returncode != 0: + raise ValueError(f"{result}") + class_view_pathname = path / "classes.dot" + class_views = await self._parse_classes(class_view_pathname) + packages_pathname = path / "packages.dot" + class_views = RepoParser._repair_namespaces(class_views=class_views, path=path) + class_view_pathname.unlink(missing_ok=True) + packages_pathname.unlink(missing_ok=True) + return class_views + + async def _parse_classes(self, class_view_pathname): + class_views = [] + if not class_view_pathname.exists(): + return class_views + async with aiofiles.open(str(class_view_pathname), mode="r") as reader: + lines = await reader.readlines() + for line in lines: + package_name, info = RepoParser._split_class_line(line) + if not package_name: + continue + class_name, members, functions = re.split(r"(?" + if begin_flag not in left or end_flag not in left: + return None, None + bix = left.find(begin_flag) + eix = left.rfind(end_flag) + info = left[bix + len(begin_flag) : eix] + info = re.sub(r"]*>", "\n", info) + return class_name, info + + @staticmethod + def _create_path_mapping(path: str | Path) -> Dict[str, str]: + mappings = { + str(path).replace("/", "."): str(path), + } + files = [] + try: + directory_path = Path(path) + if not directory_path.exists(): + return mappings + for file_path in directory_path.iterdir(): + if file_path.is_file(): + files.append(str(file_path)) + else: + subfolder_files = RepoParser._create_path_mapping(path=file_path) + mappings.update(subfolder_files) + except Exception as e: + logger.error(f"Error: {e}") + for f in files: + mappings[str(Path(f).with_suffix("")).replace("/", ".")] = str(f) + + return mappings + + @staticmethod + def _repair_namespaces(class_views: List[ClassInfo], path: str | Path) -> List[ClassInfo]: + if not class_views: + return [] + c = class_views[0] + full_key = str(path).lstrip("/").replace("/", ".") + root_namespace = RepoParser._find_root(full_key, c.package) + root_path = root_namespace.replace(".", "/") + + mappings = RepoParser._create_path_mapping(path=path) + new_mappings = {} + ix_root_namespace = len(root_namespace) + ix_root_path = len(root_path) + for k, v in mappings.items(): + nk = k[ix_root_namespace:] + nv = v[ix_root_path:] + new_mappings[nk] = nv + + for c in class_views: + c.package = RepoParser._repair_ns(c.package, new_mappings) + return class_views + + @staticmethod + def _repair_ns(package, mappings): + file_ns = package + while file_ns != "": + if file_ns not in mappings: + ix = file_ns.rfind(".") + file_ns = file_ns[0:ix] + continue + break + internal_ns = package[ix + 1 :] + ns = mappings[file_ns] + ":" + internal_ns.replace(".", ":") + return ns + + @staticmethod + def _find_root(full_key, package) -> str: + left = full_key + while left != "": + if left in package: + break + if "." not in left: + break + ix = left.find(".") + left = left[ix + 1 :] + ix = full_key.rfind(left) + return "." + full_key[0:ix] def is_func(node): diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 8fa729556..a5d2100cc 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -17,8 +17,8 @@ import inspect import os import platform import re -import typing -from typing import List, Tuple, Union +from pathlib import Path +from typing import Callable, List, Tuple, Union import aiofiles import loguru @@ -332,7 +332,7 @@ def get_class_name(cls) -> str: return f"{cls.__module__}.{cls.__name__}" -def any_to_str(val: str | typing.Callable) -> str: +def any_to_str(val: str | Callable) -> str: """Return the class name or the class name of the object, or 'val' if it's a string type.""" if isinstance(val, str): return val @@ -443,3 +443,20 @@ async def aread(file_path: str) -> str: async with aiofiles.open(str(file_path), mode="r") as reader: content = await reader.read() return content + + +async def read_file_block(filename: str | Path, lineno: int, end_lineno: int): + if not Path(filename).exists(): + return "" + lines = [] + async with aiofiles.open(str(filename), mode="r") as reader: + ix = 0 + while ix < end_lineno: + ix += 1 + line = await reader.readline() + if ix < lineno: + continue + if ix > end_lineno: + break + lines.append(line) + return "".join(lines) diff --git a/metagpt/utils/di_graph_repository.py b/metagpt/utils/di_graph_repository.py index 9bbd38d5f..08f4327fa 100644 --- a/metagpt/utils/di_graph_repository.py +++ b/metagpt/utils/di_graph_repository.py @@ -10,11 +10,12 @@ from __future__ import annotations import json from pathlib import Path +from typing import List import aiofiles import networkx -from metagpt.utils.graph_repository import GraphRepository +from metagpt.utils.graph_repository import SPO, GraphRepository class DiGraphRepository(GraphRepository): @@ -31,6 +32,18 @@ class DiGraphRepository(GraphRepository): async def update(self, subject: str, predicate: str, object_: str): pass + async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]: + result = [] + for s, o, p in self._repo.edges(data="predicate"): + if subject and subject != s: + continue + if predicate and predicate != p: + continue + if object_ and object_ != o: + continue + result.append(SPO(subject=s, predicate=p, object_=o)) + return result + def json(self) -> str: m = networkx.node_link_data(self._repo) data = json.dumps(m) @@ -53,10 +66,12 @@ class DiGraphRepository(GraphRepository): @staticmethod async def load_from(pathname: str | Path) -> GraphRepository: - name = Path(pathname).with_suffix("").name - root = Path(pathname).parent + pathname = Path(pathname) + name = pathname.with_suffix("").name + root = pathname.parent graph = DiGraphRepository(name=name, root=root) - await graph.load(pathname=pathname) + if pathname.exists(): + await graph.load(pathname=pathname) return graph @property diff --git a/metagpt/utils/graph_repository.py b/metagpt/utils/graph_repository.py index 600575b4e..37da3dee4 100644 --- a/metagpt/utils/graph_repository.py +++ b/metagpt/utils/graph_repository.py @@ -6,18 +6,38 @@ @File : graph_repository.py @Desc : Superclass for graph repository. """ + from abc import ABC, abstractmethod -from enum import Enum +from pathlib import Path +from typing import List + +from pydantic import BaseModel + +from metagpt.repo_parser import ClassInfo, RepoFileInfo +from metagpt.utils.common import concat_namespace -class GraphKeyword(Enum): +class GraphKeyword: IS = "is" CLASS = "class" FUNCTION = "function" + SOURCE_CODE = "source_code" + NULL = "" GLOBAL_VARIABLE = "global_variable" CLASS_FUNCTION = "class_function" CLASS_PROPERTY = "class_property" HAS_CLASS = "has_class" + HAS_PAGE_INFO = "has_page_info" + HAS_CLASS_VIEW = "has_class_view" + HAS_SEQUENCE_VIEW = "has_sequence_view" + HAS_ARGS_DESC = "has_args_desc" + HAS_TYPE_DESC = "has_type_desc" + + +class SPO(BaseModel): + subject: str + predicate: str + object_: str class GraphRepository(ABC): @@ -37,6 +57,94 @@ class GraphRepository(ABC): async def update(self, subject: str, predicate: str, object_: str): pass + @abstractmethod + async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]: + pass + @property def name(self) -> str: return self._repo_name + + @staticmethod + async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info: RepoFileInfo): + await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE) + file_types = {".py": "python", ".js": "javascript"} + file_type = file_types.get(Path(file_info.file).suffix, GraphKeyword.NULL) + await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=file_type) + for c in file_info.classes: + class_name = c.get("name", "") + await graph_db.insert( + subject=file_info.file, + predicate=GraphKeyword.HAS_CLASS, + object_=concat_namespace(file_info.file, class_name), + ) + await graph_db.insert( + subject=concat_namespace(file_info.file, class_name), + predicate=GraphKeyword.IS, + object_=GraphKeyword.CLASS, + ) + methods = c.get("methods", []) + for fn in methods: + await graph_db.insert( + subject=concat_namespace(file_info.file, class_name, fn), + predicate=GraphKeyword.IS, + object_=GraphKeyword.CLASS_FUNCTION, + ) + for f in file_info.functions: + await graph_db.insert( + subject=concat_namespace(file_info.file, f), predicate=GraphKeyword.IS, object_=GraphKeyword.FUNCTION + ) + for g in file_info.globals: + await graph_db.insert( + subject=concat_namespace(file_info.file, g), + predicate=GraphKeyword.IS, + object_=GraphKeyword.GLOBAL_VARIABLE, + ) + for code_block in file_info.page_info: + if code_block.tokens: + await graph_db.insert( + subject=concat_namespace(file_info.file, *code_block.tokens), + predicate=GraphKeyword.HAS_PAGE_INFO, + object_=code_block.json(ensure_ascii=False), + ) + for k, v in code_block.properties.items(): + await graph_db.insert( + subject=concat_namespace(file_info.file, k, v), + predicate=GraphKeyword.HAS_PAGE_INFO, + object_=code_block.json(ensure_ascii=False), + ) + + @staticmethod + async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[ClassInfo]): + for c in class_views: + filename, class_name = c.package.split(":", 1) + await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE) + file_types = {".py": "python", ".js": "javascript"} + file_type = file_types.get(Path(filename).suffix, GraphKeyword.NULL) + await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=file_type) + await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_CLASS, object_=class_name) + await graph_db.insert( + subject=c.package, + predicate=GraphKeyword.IS, + object_=GraphKeyword.CLASS, + ) + for vn, vt in c.attributes.items(): + await graph_db.insert( + subject=concat_namespace(c.package, vn), + predicate=GraphKeyword.IS, + object_=GraphKeyword.CLASS_PROPERTY, + ) + await graph_db.insert( + subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt + ) + for fn, desc in c.methods.items(): + await graph_db.insert( + subject=concat_namespace(c.package, fn), + predicate=GraphKeyword.IS, + object_=GraphKeyword.CLASS_FUNCTION, + ) + await graph_db.insert( + subject=concat_namespace(c.package, fn), + predicate=GraphKeyword.HAS_ARGS_DESC, + object_=desc, + ) diff --git a/requirements.txt b/requirements.txt index 4310aec6c..c4e674569 100644 --- a/requirements.txt +++ b/requirements.txt @@ -57,4 +57,5 @@ socksio~=1.0.0 gitignore-parser==0.1.9 connexion[swagger-ui] websockets~=12.0 -networkx~=3.2.1 \ No newline at end of file +networkx~=3.2.1 +pylint~=3.0.3 \ No newline at end of file diff --git a/tests/metagpt/actions/test_rebuild_class_view.py b/tests/metagpt/actions/test_rebuild_class_view.py new file mode 100644 index 000000000..955c6ae3b --- /dev/null +++ b/tests/metagpt/actions/test_rebuild_class_view.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/20 +@Author : mashenquan +@File : test_rebuild_class_view.py +@Desc : Unit tests for rebuild_class_view.py +""" +from pathlib import Path + +import pytest + +from metagpt.actions.rebuild_class_view import RebuildClassView +from metagpt.llm import LLM + + +@pytest.mark.asyncio +async def test_rebuild(): + action = RebuildClassView(name="RedBean", context=Path(__file__).parent.parent, llm=LLM()) + await action.run() + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/test_repo_parser.py b/tests/metagpt/test_repo_parser.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/metagpt/utils/test_di_graph_repository.py b/tests/metagpt/utils/test_di_graph_repository.py index ec2cb4d01..0a8011e51 100644 --- a/tests/metagpt/utils/test_di_graph_repository.py +++ b/tests/metagpt/utils/test_di_graph_repository.py @@ -14,9 +14,8 @@ from pydantic import BaseModel from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.repo_parser import RepoParser -from metagpt.utils.common import concat_namespace from metagpt.utils.di_graph_repository import DiGraphRepository -from metagpt.utils.graph_repository import GraphKeyword +from metagpt.utils.graph_repository import GraphRepository @pytest.mark.asyncio @@ -57,23 +56,7 @@ async def test_js_parser(): repo_parser = RepoParser(base_directory=data.path) symbols = repo_parser.generate_symbols() for s in symbols: - ns = s.get("file", "") - for c in s.get("classes", []): - await graph.insert( - subject=concat_namespace(ns, c), predicate=GraphKeyword.IS.value, object_=GraphKeyword.CLASS.value - ) - for f in s.get("functions", []): - await graph.insert( - subject=concat_namespace(ns, f), - predicate=GraphKeyword.IS.value, - object_=GraphKeyword.FUNCTION.value, - ) - for g in s.get("globals", []): - await graph.insert( - subject=concat_namespace(ns, g), - predicate=GraphKeyword.IS.value, - object_=GraphKeyword.GLOBAL_VARIABLE.value, - ) + await GraphRepository.update_graph_db(graph_db=graph, file_info=s) data = graph.json() assert data @@ -85,35 +68,14 @@ async def test_codes(): graph = DiGraphRepository(name="test", root=path) symbols = repo_parser.generate_symbols() - for s in symbols: - ns = s.get("file", "") - for c in s.get("classes", []): - class_name = c.get("name", "") - await graph.insert( - subject=ns, predicate=GraphKeyword.HAS_CLASS.value, object_=concat_namespace(ns, class_name) - ) - await graph.insert( - subject=concat_namespace(ns, class_name), - predicate=GraphKeyword.IS.value, - object_=GraphKeyword.CLASS.value, - ) - methods = c.get("methods", []) - for fn in methods: - await graph.insert( - subject=concat_namespace(ns, class_name, fn), - predicate=GraphKeyword.IS.value, - object_=GraphKeyword.CLASS_FUNCTION.value, - ) - for f in s.get("functions", []): - await graph.insert( - subject=concat_namespace(ns, f), predicate=GraphKeyword.IS.value, object_=GraphKeyword.FUNCTION.value - ) - for g in s.get("globals", []): - await graph.insert( - subject=concat_namespace(ns, g), - predicate=GraphKeyword.IS.value, - object_=GraphKeyword.GLOBAL_VARIABLE.value, - ) + for file_info in symbols: + for code_block in file_info.page_info: + try: + val = code_block.json(ensure_ascii=False) + assert val + except TypeError as e: + assert not e + await GraphRepository.update_graph_db(graph_db=graph, file_info=file_info) data = graph.json() assert data print(data)