reconstruct object in rag node

This commit is contained in:
seehi 2024-03-07 12:43:43 +08:00
parent af63eab13c
commit f149007752

View file

@ -134,7 +134,7 @@ class SimpleEngine(RetrieverQueryEngine):
query_bundle = QueryBundle(query) if isinstance(query, str) else query
nodes = await super().aretrieve(query_bundle)
self._try_reconstruct_object(nodes)
self._try_reconstruct_obj(nodes)
return nodes
def add_docs(self, input_files: list[str]):
@ -149,20 +149,18 @@ class SimpleEngine(RetrieverQueryEngine):
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
self._ensure_retriever_modifiable()
nodes = [
TextNode(
text=obj.rag_key(),
metadata={
"is_obj": True,
"obj_dict": obj.model_dump(),
"obj_cls_name": obj.__class__.__name__,
"obj_mod_name": obj.__class__.__module__,
},
)
for obj in objs
]
nodes = [TextNode(text=obj.rag_key(), metadata=self._get_obj_metadata(obj)) for obj in objs]
self._save_nodes(nodes)
def _get_obj_metadata(self, obj: RAGObject) -> dict:
metadata = {
"is_obj": True,
"obj_dict": obj.model_dump(),
"obj_cls_name": obj.__class__.__name__,
"obj_mod_name": obj.__class__.__module__,
}
return metadata
def _ensure_retriever_modifiable(self):
if not isinstance(self.retriever, ModifiableRAGRetriever):
raise TypeError(f"the retriever is not modifiable: {type(self.retriever)}")
@ -175,7 +173,7 @@ class SimpleEngine(RetrieverQueryEngine):
self.index.insert_nodes(nodes)
@staticmethod
def _try_reconstruct_object(nodes: list[NodeWithScore]):
def _try_reconstruct_obj(nodes: list[NodeWithScore]):
"""If node is object, then dynamically reconstruct object, and save object to node.metadata["obj"]."""
for node in nodes:
if node.metadata.get("is_obj"):