feat: sync metagpt/utils/di_graph_repository.py

feat: +graph db action
This commit is contained in:
莘权 马 2024-06-13 12:00:57 +08:00
parent 8aa4c970d9
commit e8cc6b7193
2 changed files with 53 additions and 6 deletions

View file

@ -0,0 +1,43 @@
"""
graph_db_action.py
This module defines the GraphDBAction class, which interacts with a graph database.
Classes:
GraphDBAction: An action class that interacts with a graph database.
Usage Example:
action = GraphDBAction()
await action.load_graph_db('path/to/graph_db')
action = GraphDBAction(graph_db=external_graph_db)
"""
from __future__ import annotations
from pathlib import Path
from typing import Optional
from metagpt.actions import Action
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphRepository
class GraphDBAction(Action):
"""
An action class that interacts with a graph database.
Attributes:
graph_db (Optional[GraphRepository]): The graph database instance.
"""
graph_db: Optional[GraphRepository] = None
async def load_graph_db(self, pathname: str | Path):
"""
Asynchronously loads the graph database from the specified path.
Args:
pathname (str | Path): The path to the graph database file.
"""
self.graph_db = await DiGraphRepository.load_from(pathname)

View file

@ -23,8 +23,8 @@ from metagpt.utils.graph_repository import SPO, GraphRepository
class DiGraphRepository(GraphRepository):
"""Graph repository based on DiGraph."""
def __init__(self, name: str, **kwargs):
super().__init__(name=name, **kwargs)
def __init__(self, name: str | Path, **kwargs):
super().__init__(name=str(name), **kwargs)
self._repo = networkx.DiGraph()
async def insert(self, subject: str, predicate: str, object_: str):
@ -112,8 +112,14 @@ class DiGraphRepository(GraphRepository):
async def load(self, pathname: str | Path):
"""Load a directed graph repository from a JSON file."""
data = await aread(filename=pathname, encoding="utf-8")
m = json.loads(data)
self.load_json(data)
def load_json(self, val: str):
if not val:
return self
m = json.loads(val)
self._repo = networkx.node_link_graph(m)
return self
@staticmethod
async def load_from(pathname: str | Path) -> GraphRepository:
@ -126,9 +132,7 @@ class DiGraphRepository(GraphRepository):
GraphRepository: A new instance of the graph repository loaded from the specified JSON file.
"""
pathname = Path(pathname)
name = pathname.with_suffix("").name
root = pathname.parent
graph = DiGraphRepository(name=name, root=root)
graph = DiGraphRepository(name=pathname.stem, root=pathname.parent)
if pathname.exists():
await graph.load(pathname=pathname)
return graph