mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-09 07:42:38 +02:00
reconstruct object in rag node
This commit is contained in:
parent
800054aae6
commit
af63eab13c
2 changed files with 39 additions and 15 deletions
|
|
@ -36,6 +36,7 @@ from metagpt.rag.factories import (
|
|||
from metagpt.rag.interface import RAGObject
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
|
||||
from metagpt.rag.schema import BaseIndexConfig, BaseRankerConfig, BaseRetrieverConfig
|
||||
from metagpt.utils.common import import_class
|
||||
|
||||
|
||||
class SimpleEngine(RetrieverQueryEngine):
|
||||
|
|
@ -129,9 +130,12 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
return await self.aquery(content)
|
||||
|
||||
async def aretrieve(self, query: QueryType) -> list[NodeWithScore]:
|
||||
"""Allow query to be str"""
|
||||
"""Allow query to be str."""
|
||||
query_bundle = QueryBundle(query) if isinstance(query, str) else query
|
||||
return await super().aretrieve(query_bundle)
|
||||
|
||||
nodes = await super().aretrieve(query_bundle)
|
||||
self._try_reconstruct_object(nodes)
|
||||
return nodes
|
||||
|
||||
def add_docs(self, input_files: list[str]):
|
||||
"""Add docs to retriever. retriever must has add_nodes func."""
|
||||
|
|
@ -145,7 +149,18 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
|
||||
self._ensure_retriever_modifiable()
|
||||
|
||||
nodes = [TextNode(text=obj.rag_key(), metadata={"obj": obj.model_dump()}) for obj in objs]
|
||||
nodes = [
|
||||
TextNode(
|
||||
text=obj.rag_key(),
|
||||
metadata={
|
||||
"is_obj": True,
|
||||
"obj_dict": obj.model_dump(),
|
||||
"obj_cls_name": obj.__class__.__name__,
|
||||
"obj_mod_name": obj.__class__.__module__,
|
||||
},
|
||||
)
|
||||
for obj in objs
|
||||
]
|
||||
self._save_nodes(nodes)
|
||||
|
||||
def _ensure_retriever_modifiable(self):
|
||||
|
|
@ -158,3 +173,11 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
|
||||
# for persist
|
||||
self.index.insert_nodes(nodes)
|
||||
|
||||
@staticmethod
|
||||
def _try_reconstruct_object(nodes: list[NodeWithScore]):
|
||||
"""If node is object, then dynamically reconstruct object, and save object to node.metadata["obj"]."""
|
||||
for node in nodes:
|
||||
if node.metadata.get("is_obj"):
|
||||
obj_cls = import_class(node.metadata["obj_cls_name"], node.metadata["obj_mod_name"])
|
||||
node.metadata["obj"] = obj_cls(**node.metadata["obj_dict"])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue