reconstruct object in rag node

This commit is contained in:
seehi 2024-03-07 12:07:27 +08:00
parent 800054aae6
commit af63eab13c
2 changed files with 39 additions and 15 deletions

View file

@ -36,6 +36,7 @@ from metagpt.rag.factories import (
from metagpt.rag.interface import RAGObject
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
from metagpt.rag.schema import BaseIndexConfig, BaseRankerConfig, BaseRetrieverConfig
from metagpt.utils.common import import_class
class SimpleEngine(RetrieverQueryEngine):
@ -129,9 +130,12 @@ class SimpleEngine(RetrieverQueryEngine):
return await self.aquery(content)
async def aretrieve(self, query: QueryType) -> list[NodeWithScore]:
"""Allow query to be str"""
"""Allow query to be str."""
query_bundle = QueryBundle(query) if isinstance(query, str) else query
return await super().aretrieve(query_bundle)
nodes = await super().aretrieve(query_bundle)
self._try_reconstruct_object(nodes)
return nodes
def add_docs(self, input_files: list[str]):
"""Add docs to retriever. retriever must has add_nodes func."""
@ -145,7 +149,18 @@ class SimpleEngine(RetrieverQueryEngine):
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
self._ensure_retriever_modifiable()
nodes = [TextNode(text=obj.rag_key(), metadata={"obj": obj.model_dump()}) for obj in objs]
nodes = [
TextNode(
text=obj.rag_key(),
metadata={
"is_obj": True,
"obj_dict": obj.model_dump(),
"obj_cls_name": obj.__class__.__name__,
"obj_mod_name": obj.__class__.__module__,
},
)
for obj in objs
]
self._save_nodes(nodes)
def _ensure_retriever_modifiable(self):
@ -158,3 +173,11 @@ class SimpleEngine(RetrieverQueryEngine):
# for persist
self.index.insert_nodes(nodes)
@staticmethod
def _try_reconstruct_object(nodes: list[NodeWithScore]):
    """Rebuild the original object for each retrieved node that is flagged as an object.

    A node is treated as an object node when its metadata has a truthy "is_obj" entry.
    For such a node, the class is resolved with import_class from metadata["obj_cls_name"]
    and metadata["obj_mod_name"], instantiated with metadata["obj_dict"] as keyword
    arguments, and the resulting instance is stored under metadata["obj"].
    Nodes without the flag are left untouched. Nothing is returned; nodes are updated in place.
    """
    for retrieved in nodes:
        # Guard clause: only nodes explicitly marked as objects are reconstructed.
        if not retrieved.metadata.get("is_obj"):
            continue

        cls_name = retrieved.metadata["obj_cls_name"]
        mod_name = retrieved.metadata["obj_mod_name"]
        target_cls = import_class(cls_name, mod_name)

        # The stored dict is expanded into keyword arguments for the constructor.
        rebuilt = target_cls(**retrieved.metadata["obj_dict"])
        retrieved.metadata["obj"] = rebuilt