Review note: added lines in this patch were revised during review; the post-image blob ids on the `index` lines were not recomputed.

diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py
index 034bcfa56..8aa3be2b6 100644
--- a/metagpt/memory/brain_memory.py
+++ b/metagpt/memory/brain_memory.py
@@ -4,7 +4,7 @@
 @Time : 2023/8/18
 @Author : mashenquan
 @File : brain_memory.py
-@Desc : Support memory for multiple tasks and multiple mainlines.
+@Desc : Support memory for multiple tasks and multiple mainlines. Obsoleted by `utils/*_repository.py`.
 @Modified By: mashenquan, 2023/9/4. + redis memory cache.
 """
 import json
diff --git a/metagpt/repo_parser.py b/metagpt/repo_parser.py
index b84dbab9a..65c2959e4 100644
--- a/metagpt/repo_parser.py
+++ b/metagpt/repo_parser.py
@@ -51,7 +51,11 @@ class RepoParser(BaseModel):
     def generate_symbols(self):
         files_classes = []
         directory = self.base_directory
-        for path in directory.rglob("*.py"):
+        matching_files = []
+        extensions = ["*.py", "*.js"]
+        for ext in extensions:
+            matching_files += directory.rglob(ext)
+        for path in matching_files:
             tree = self.parse_file(path)
             file_info = self.extract_class_and_function_info(tree, path)
             files_classes.append(file_info)
diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py
index 2a3d22698..575c77b5e 100644
--- a/metagpt/utils/common.py
+++ b/metagpt/utils/common.py
@@ -393,3 +393,8 @@ def format_value(value):
     for k, v in merged_opts.items():
         value = value.replace("{" + f"{k}" + "}", str(v))
     return value
+
+
+def concat_namespace(*args) -> str:
+    """Join the string form of every argument with ":", e.g. `("a.py", "Game")` gives `"a.py:Game"`."""
+    return ":".join(str(value) for value in args)
diff --git a/metagpt/utils/di_graph_repository.py b/metagpt/utils/di_graph_repository.py
new file mode 100644
index 000000000..9bbd38d5f
--- /dev/null
+++ b/metagpt/utils/di_graph_repository.py
@@ -0,0 +1,90 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2023/12/19
+@Author : mashenquan
+@File : di_graph_repository.py
+@Desc : Graph repository based on DiGraph
+"""
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+import aiofiles
+import networkx
+
+from metagpt.utils.graph_repository import GraphRepository
+
+
+class DiGraphRepository(GraphRepository):
+    """Graph repository that keeps its triples in a `networkx.DiGraph`."""
+
+    def __init__(self, name: str, **kwargs):
+        super().__init__(name=name, **kwargs)
+        self._repo = networkx.DiGraph()
+
+    async def insert(self, subject: str, predicate: str, object_: str):
+        """Add the edge from `subject` to `object_` with `predicate` stored as an edge attribute.
+
+        Inserting the same `subject`/`object_` pair again updates that edge's `predicate`.
+        """
+        self._repo.add_edge(subject, object_, predicate=predicate)
+
+    async def upsert(self, subject: str, predicate: str, object_: str):
+        """Currently a no-op: the triple is ignored."""
+        pass
+
+    async def update(self, subject: str, predicate: str, object_: str):
+        """Currently a no-op: the triple is ignored."""
+        pass
+
+    def json(self) -> str:
+        """Return the graph serialized as a node-link JSON string."""
+        m = networkx.node_link_data(self._repo)
+        data = json.dumps(m)
+        return data
+
+    async def save(self, path: str | Path = None):
+        """Write the graph to `{path}/{name}.json`, creating the directory if needed.
+
+        :param path: Target directory; defaults to the `root` keyword given to the constructor.
+        :raises ValueError: If neither `path` nor `root` is set.
+        """
+        data = self.json()
+        path = path or self._kwargs.get("root")
+        if not path:
+            raise ValueError("`path` is not set and the repository has no `root`.")
+        # Normalize before any filesystem call: a plain `str` has no `exists()`/`mkdir()`.
+        directory = Path(path)
+        directory.mkdir(parents=True, exist_ok=True)
+        pathname = directory / self.name
+        async with aiofiles.open(str(pathname.with_suffix(".json")), mode="w", encoding="utf-8") as writer:
+            await writer.write(data)
+
+    async def load(self, pathname: str | Path):
+        """Replace the in-memory graph with the node-link JSON read from `pathname`."""
+        async with aiofiles.open(str(pathname), mode="r", encoding="utf-8") as reader:
+            data = await reader.read(-1)
+        m = json.loads(data)
+        self._repo = networkx.node_link_graph(m)
+
+    @staticmethod
+    async def load_from(pathname: str | Path) -> GraphRepository:
+        """Build a repository from a saved file; name and root are derived from `pathname`."""
+        name = Path(pathname).with_suffix("").name
+        root = Path(pathname).parent
+        graph = DiGraphRepository(name=name, root=root)
+        await graph.load(pathname=pathname)
+        return graph
+
+    @property
+    def root(self) -> str | Path | None:
+        """The `root` keyword given to the constructor, or `None` if it was not given."""
+        return self._kwargs.get("root")
+
+    @property
+    def pathname(self) -> Path:
+        """Path of the JSON file for this repository: `{root}/{name}.json`."""
+        p = Path(self.root) / self.name
+        return p.with_suffix(".json")
diff --git a/metagpt/utils/graph_repository.py b/metagpt/utils/graph_repository.py
new file mode 100644
index 000000000..600575b4e
--- /dev/null
+++ b/metagpt/utils/graph_repository.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2023/12/19
+@Author : mashenquan
+@File : graph_repository.py
+@Desc : Superclass for graph repository.
+"""
+from abc import ABC, abstractmethod
+from enum import Enum
+
+
+class GraphKeyword(Enum):
+    """Predicate and object keywords used when describing code symbols as triples."""
+
+    IS = "is"
+    CLASS = "class"
+    FUNCTION = "function"
+    GLOBAL_VARIABLE = "global_variable"
+    CLASS_FUNCTION = "class_function"
+    CLASS_PROPERTY = "class_property"
+    HAS_CLASS = "has_class"
+
+
+class GraphRepository(ABC):
+    """Abstract base for repositories that store (subject, predicate, object) triples."""
+
+    def __init__(self, name: str, **kwargs):
+        self._repo_name = name
+        self._kwargs = kwargs  # Extra keyword options, kept as given for subclasses to read.
+
+    @abstractmethod
+    async def insert(self, subject: str, predicate: str, object_: str):
+        pass
+
+    @abstractmethod
+    async def upsert(self, subject: str, predicate: str, object_: str):
+        pass
+
+    @abstractmethod
+    async def update(self, subject: str, predicate: str, object_: str):
+        pass
+
+    @property
+    def name(self) -> str:
+        """Name the repository was constructed with."""
+        return self._repo_name
diff --git a/requirements.txt b/requirements.txt
index d2a4e5bb4..4310aec6c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -56,4 +56,5 @@ zhipuai==1.0.7
 socksio~=1.0.0
 gitignore-parser==0.1.9
 connexion[swagger-ui]
-websockets~=12.0
\ No newline at end of file
+websockets~=12.0
+networkx~=3.2.1
diff --git a/tests/metagpt/utils/test_di_graph_repository.py b/tests/metagpt/utils/test_di_graph_repository.py
new file mode 100644
index 000000000..7a9e58d1c
--- /dev/null
+++ b/tests/metagpt/utils/test_di_graph_repository.py
@@ -0,0 +1,124 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2023/12/19
+@Author : mashenquan
+@File : test_di_graph_repository.py
+@Desc : Unit tests for di_graph_repository.py
+"""
+
+from pathlib import Path
+
+import pytest
+from pydantic import BaseModel
+
+from metagpt.const import DEFAULT_WORKSPACE_ROOT
+from metagpt.repo_parser import RepoParser
+from metagpt.utils.common import concat_namespace
+from metagpt.utils.di_graph_repository import DiGraphRepository
+from metagpt.utils.graph_repository import GraphKeyword
+
+
+@pytest.mark.asyncio
+async def test_di_graph_repository(tmp_path):
+    class Input(BaseModel):
+        s: str
+        p: str
+        o: str
+
+    inputs = [
+        {"s": "main.py:Game:draw", "p": "method:hasDescription", "o": "Draw image"},
+        {"s": "main.py:Game:draw", "p": "method:hasDescription", "o": "Show image"},
+    ]
+    # `tmp_path` keeps the generated file out of the source tree; pytest cleans it up.
+    graph = DiGraphRepository(name="test", root=tmp_path)
+    for i in inputs:
+        data = Input(**i)
+        await graph.insert(subject=data.s, predicate=data.p, object_=data.o)
+        v = graph.json()
+        assert v
+    await graph.save()
+    assert graph.pathname.exists()
+
+
+@pytest.mark.asyncio
+async def test_js_parser():
+    class Input(BaseModel):
+        path: str
+
+    inputs = [
+        {"path": str(Path(__file__).parent / "../../data/code")},
+    ]
+    path = Path(__file__).parent
+    graph = DiGraphRepository(name="test", root=path)
+    for i in inputs:
+        data = Input(**i)
+        repo_parser = RepoParser(base_directory=data.path)
+        symbols = repo_parser.generate_symbols()
+        for s in symbols:
+            ns = s.get("file", "")
+            # NOTE(review): `test_codes` below reads class entries as dicts (`c.get("name")`); here the
+            # whole entry is stringified into the subject -- confirm which shape `generate_symbols` returns.
+            for c in s.get("classes", []):
+                await graph.insert(
+                    subject=concat_namespace(ns, c), predicate=GraphKeyword.IS.value, object_=GraphKeyword.CLASS.value
+                )
+            for f in s.get("functions", []):
+                await graph.insert(
+                    subject=concat_namespace(ns, f),
+                    predicate=GraphKeyword.IS.value,
+                    object_=GraphKeyword.FUNCTION.value,
+                )
+            for g in s.get("globals", []):
+                await graph.insert(
+                    subject=concat_namespace(ns, g),
+                    predicate=GraphKeyword.IS.value,
+                    object_=GraphKeyword.GLOBAL_VARIABLE.value,
+                )
+    data = graph.json()
+    assert data
+
+
+@pytest.mark.asyncio
+async def test_codes():
+    path = DEFAULT_WORKSPACE_ROOT / "snake_game"
+    repo_parser = RepoParser(base_directory=path)
+
+    graph = DiGraphRepository(name="test", root=path)
+    symbols = repo_parser.generate_symbols()
+    for s in symbols:
+        ns = s.get("file", "")
+        for c in s.get("classes", []):
+            class_name = c.get("name", "")
+            await graph.insert(
+                subject=ns, predicate=GraphKeyword.HAS_CLASS.value, object_=concat_namespace(ns, class_name)
+            )
+            await graph.insert(
+                subject=concat_namespace(ns, class_name),
+                predicate=GraphKeyword.IS.value,
+                object_=GraphKeyword.CLASS.value,
+            )
+            methods = c.get("methods", [])
+            for fn in methods:
+                await graph.insert(
+                    subject=concat_namespace(ns, class_name, fn),
+                    predicate=GraphKeyword.IS.value,
+                    object_=GraphKeyword.CLASS_FUNCTION.value,
+                )
+        for f in s.get("functions", []):
+            await graph.insert(
+                subject=concat_namespace(ns, f), predicate=GraphKeyword.IS.value, object_=GraphKeyword.FUNCTION.value
+            )
+        for g in s.get("globals", []):
+            await graph.insert(
+                subject=concat_namespace(ns, g),
+                predicate=GraphKeyword.IS.value,
+                object_=GraphKeyword.GLOBAL_VARIABLE.value,
+            )
+    data = graph.json()
+    assert data
+    print(data)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-s"])