diff --git a/metagpt/actions/graph_db_action.py b/metagpt/actions/graph_db_action.py new file mode 100644 index 000000000..d050760f6 --- /dev/null +++ b/metagpt/actions/graph_db_action.py @@ -0,0 +1,43 @@ +""" +graph_db_action.py + +This module defines the GraphDBAction class, which interacts with a graph database. + +Classes: + GraphDBAction: An action class that interacts with a graph database. + +Usage Example: + action = GraphDBAction() + await action.load_graph_db('path/to/graph_db') + + action = GraphDBAction(graph_db=external_graph_db) +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Optional + +from metagpt.actions import Action +from metagpt.utils.di_graph_repository import DiGraphRepository +from metagpt.utils.graph_repository import GraphRepository + + +class GraphDBAction(Action): + """ + An action class that interacts with a graph database. + + Attributes: + graph_db (Optional[GraphRepository]): The graph database instance. + """ + + graph_db: Optional[GraphRepository] = None + + async def load_graph_db(self, pathname: str | Path): + """ + Asynchronously loads the graph database from the specified path. + + Args: + pathname (str | Path): The path to the graph database file. + """ + self.graph_db = await DiGraphRepository.load_from(pathname) diff --git a/metagpt/utils/di_graph_repository.py b/metagpt/utils/di_graph_repository.py index fee706ece..f8fabfbdc 100644 --- a/metagpt/utils/di_graph_repository.py +++ b/metagpt/utils/di_graph_repository.py @@ -23,8 +23,8 @@ from metagpt.utils.graph_repository import SPO, GraphRepository class DiGraphRepository(GraphRepository): """Graph repository based on DiGraph.""" - def __init__(self, name: str, **kwargs): - super().__init__(name=name, **kwargs) + def __init__(self, name: str | Path, **kwargs): + super().__init__(name=str(name), **kwargs) self._repo = networkx.DiGraph() async def insert(self, subject: str, predicate: str, object_: str): @@ -112,8 +112,14 @@ class DiGraphRepository(GraphRepository): async def load(self, pathname: str | Path): """Load a directed graph repository from a JSON file.""" data = await aread(filename=pathname, encoding="utf-8") - m = json.loads(data) + self.load_json(data) + + def load_json(self, val: str): + if not val: + return self + m = json.loads(val) self._repo = networkx.node_link_graph(m) + return self @staticmethod async def load_from(pathname: str | Path) -> GraphRepository: @@ -126,9 +132,7 @@ class DiGraphRepository(GraphRepository): GraphRepository: A new instance of the graph repository loaded from the specified JSON file. """ pathname = Path(pathname) - name = pathname.with_suffix("").name - root = pathname.parent - graph = DiGraphRepository(name=name, root=root) + graph = DiGraphRepository(name=pathname.stem, root=pathname.parent) if pathname.exists(): await graph.load(pathname=pathname) return graph