mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-24 14:15:17 +02:00
feat: add GraphRepository
This commit is contained in:
parent
76b2dfecdc
commit
3a35c0a0cd
7 changed files with 244 additions and 3 deletions
|
|
@ -4,7 +4,7 @@
|
|||
@Time : 2023/8/18
|
||||
@Author : mashenquan
|
||||
@File : brain_memory.py
|
||||
@Desc : Support memory for multiple tasks and multiple mainlines.
|
||||
@Desc : Support memory for multiple tasks and multiple mainlines. Obsoleted by `utils/*_repository.py`.
|
||||
@Modified By: mashenquan, 2023/9/4. + redis memory cache.
|
||||
"""
|
||||
import json
|
||||
|
|
|
|||
|
|
@ -51,7 +51,11 @@ class RepoParser(BaseModel):
|
|||
def generate_symbols(self):
|
||||
files_classes = []
|
||||
directory = self.base_directory
|
||||
for path in directory.rglob("*.py"):
|
||||
matching_files = []
|
||||
extensions = ["*.py", "*.js"]
|
||||
for ext in extensions:
|
||||
matching_files += directory.rglob(ext)
|
||||
for path in matching_files:
|
||||
tree = self.parse_file(path)
|
||||
file_info = self.extract_class_and_function_info(tree, path)
|
||||
files_classes.append(file_info)
|
||||
|
|
|
|||
|
|
@ -393,3 +393,7 @@ def format_value(value):
|
|||
for k, v in merged_opts.items():
|
||||
value = value.replace("{" + f"{k}" + "}", str(v))
|
||||
return value
|
||||
|
||||
|
||||
def concat_namespace(*args) -> str:
|
||||
return ":".join(str(value) for value in args)
|
||||
|
|
|
|||
69
metagpt/utils/di_graph_repository.py
Normal file
69
metagpt/utils/di_graph_repository.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/19
|
||||
@Author : mashenquan
|
||||
@File : di_graph_repository.py
|
||||
@Desc : Graph repository based on DiGraph
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
import networkx
|
||||
|
||||
from metagpt.utils.graph_repository import GraphRepository
|
||||
|
||||
|
||||
class DiGraphRepository(GraphRepository):
|
||||
def __init__(self, name: str, **kwargs):
|
||||
super().__init__(name=name, **kwargs)
|
||||
self._repo = networkx.DiGraph()
|
||||
|
||||
async def insert(self, subject: str, predicate: str, object_: str):
|
||||
self._repo.add_edge(subject, object_, predicate=predicate)
|
||||
|
||||
async def upsert(self, subject: str, predicate: str, object_: str):
|
||||
pass
|
||||
|
||||
async def update(self, subject: str, predicate: str, object_: str):
|
||||
pass
|
||||
|
||||
def json(self) -> str:
|
||||
m = networkx.node_link_data(self._repo)
|
||||
data = json.dumps(m)
|
||||
return data
|
||||
|
||||
async def save(self, path: str | Path = None):
|
||||
data = self.json()
|
||||
path = path or self._kwargs.get("root")
|
||||
if not path.exists():
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
pathname = Path(path) / self.name
|
||||
async with aiofiles.open(str(pathname.with_suffix(".json")), mode="w", encoding="utf-8") as writer:
|
||||
await writer.write(data)
|
||||
|
||||
async def load(self, pathname: str | Path):
|
||||
async with aiofiles.open(str(pathname), mode="r", encoding="utf-8") as reader:
|
||||
data = await reader.read(-1)
|
||||
m = json.loads(data)
|
||||
self._repo = networkx.node_link_graph(m)
|
||||
|
||||
@staticmethod
|
||||
async def load_from(pathname: str | Path) -> GraphRepository:
|
||||
name = Path(pathname).with_suffix("").name
|
||||
root = Path(pathname).parent
|
||||
graph = DiGraphRepository(name=name, root=root)
|
||||
await graph.load(pathname=pathname)
|
||||
return graph
|
||||
|
||||
@property
|
||||
def root(self) -> str:
|
||||
return self._kwargs.get("root")
|
||||
|
||||
@property
|
||||
def pathname(self) -> Path:
|
||||
p = Path(self.root) / self.name
|
||||
return p.with_suffix(".json")
|
||||
42
metagpt/utils/graph_repository.py
Normal file
42
metagpt/utils/graph_repository.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/19
|
||||
@Author : mashenquan
|
||||
@File : graph_repository.py
|
||||
@Desc : Superclass for graph repository.
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class GraphKeyword(Enum):
|
||||
IS = "is"
|
||||
CLASS = "class"
|
||||
FUNCTION = "function"
|
||||
GLOBAL_VARIABLE = "global_variable"
|
||||
CLASS_FUNCTION = "class_function"
|
||||
CLASS_PROPERTY = "class_property"
|
||||
HAS_CLASS = "has_class"
|
||||
|
||||
|
||||
class GraphRepository(ABC):
|
||||
def __init__(self, name: str, **kwargs):
|
||||
self._repo_name = name
|
||||
self._kwargs = kwargs
|
||||
|
||||
@abstractmethod
|
||||
async def insert(self, subject: str, predicate: str, object_: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def upsert(self, subject: str, predicate: str, object_: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def update(self, subject: str, predicate: str, object_: str):
|
||||
pass
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._repo_name
|
||||
|
|
@ -56,4 +56,5 @@ zhipuai==1.0.7
|
|||
socksio~=1.0.0
|
||||
gitignore-parser==0.1.9
|
||||
connexion[swagger-ui]
|
||||
websockets~=12.0
|
||||
websockets~=12.0
|
||||
networkx~=3.2.1
|
||||
121
tests/metagpt/utils/test_di_graph_repository.py
Normal file
121
tests/metagpt/utils/test_di_graph_repository.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/19
|
||||
@Author : mashenquan
|
||||
@File : test_di_graph_repository.py
|
||||
@Desc : Unit tests for di_graph_repository.py
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.repo_parser import RepoParser
|
||||
from metagpt.utils.common import concat_namespace
|
||||
from metagpt.utils.di_graph_repository import DiGraphRepository
|
||||
from metagpt.utils.graph_repository import GraphKeyword
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_di_graph_repository():
|
||||
class Input(BaseModel):
|
||||
s: str
|
||||
p: str
|
||||
o: str
|
||||
|
||||
inputs = [
|
||||
{"s": "main.py:Game:draw", "p": "method:hasDescription", "o": "Draw image"},
|
||||
{"s": "main.py:Game:draw", "p": "method:hasDescription", "o": "Show image"},
|
||||
]
|
||||
path = Path(__file__).parent
|
||||
graph = DiGraphRepository(name="test", root=path)
|
||||
for i in inputs:
|
||||
data = Input(**i)
|
||||
await graph.insert(subject=data.s, predicate=data.p, object_=data.o)
|
||||
v = graph.json()
|
||||
assert v
|
||||
await graph.save()
|
||||
assert graph.pathname.exists()
|
||||
graph.pathname.unlink()
|
||||
|
||||
|
||||
async def test_js_parser():
|
||||
class Input(BaseModel):
|
||||
path: str
|
||||
|
||||
inputs = [
|
||||
{"path": str(Path(__file__).parent / "../../data/code")},
|
||||
]
|
||||
path = Path(__file__).parent
|
||||
graph = DiGraphRepository(name="test", root=path)
|
||||
for i in inputs:
|
||||
data = Input(**i)
|
||||
repo_parser = RepoParser(base_directory=data.path)
|
||||
symbols = repo_parser.generate_symbols()
|
||||
for s in symbols:
|
||||
ns = s.get("file", "")
|
||||
for c in s.get("classes", []):
|
||||
await graph.insert(
|
||||
subject=concat_namespace(ns, c), predicate=GraphKeyword.IS.value, object_=GraphKeyword.CLASS.value
|
||||
)
|
||||
for f in s.get("functions", []):
|
||||
await graph.insert(
|
||||
subject=concat_namespace(ns, f),
|
||||
predicate=GraphKeyword.IS.value,
|
||||
object_=GraphKeyword.FUNCTION.value,
|
||||
)
|
||||
for g in s.get("globals", []):
|
||||
await graph.insert(
|
||||
subject=concat_namespace(ns, g),
|
||||
predicate=GraphKeyword.IS.value,
|
||||
object_=GraphKeyword.GLOBAL_VARIABLE.value,
|
||||
)
|
||||
data = graph.json()
|
||||
assert data
|
||||
|
||||
|
||||
async def test_codes():
|
||||
path = DEFAULT_WORKSPACE_ROOT / "snake_game"
|
||||
repo_parser = RepoParser(base_directory=path)
|
||||
|
||||
graph = DiGraphRepository(name="test", root=path)
|
||||
symbols = repo_parser.generate_symbols()
|
||||
for s in symbols:
|
||||
ns = s.get("file", "")
|
||||
for c in s.get("classes", []):
|
||||
class_name = c.get("name", "")
|
||||
await graph.insert(
|
||||
subject=ns, predicate=GraphKeyword.HAS_CLASS.value, object_=concat_namespace(ns, class_name)
|
||||
)
|
||||
await graph.insert(
|
||||
subject=concat_namespace(ns, class_name),
|
||||
predicate=GraphKeyword.IS.value,
|
||||
object_=GraphKeyword.CLASS.value,
|
||||
)
|
||||
methods = c.get("methods", [])
|
||||
for fn in methods:
|
||||
await graph.insert(
|
||||
subject=concat_namespace(ns, class_name, fn),
|
||||
predicate=GraphKeyword.IS.value,
|
||||
object_=GraphKeyword.CLASS_FUNCTION.value,
|
||||
)
|
||||
for f in s.get("functions", []):
|
||||
await graph.insert(
|
||||
subject=concat_namespace(ns, f), predicate=GraphKeyword.IS.value, object_=GraphKeyword.FUNCTION.value
|
||||
)
|
||||
for g in s.get("globals", []):
|
||||
await graph.insert(
|
||||
subject=concat_namespace(ns, g),
|
||||
predicate=GraphKeyword.IS.value,
|
||||
object_=GraphKeyword.GLOBAL_VARIABLE.value,
|
||||
)
|
||||
data = graph.json()
|
||||
assert data
|
||||
print(data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
Loading…
Add table
Add a link
Reference in a new issue