rag add objs

This commit is contained in:
seehi 2024-02-07 21:40:41 +08:00
parent a98da52c0e
commit ab045ccacd

View file

@ -94,14 +94,19 @@ class SimpleEngine(RetrieverQueryEngine):
def add_docs(self, input_files: list[str]):
"""Add docs to retriever. retriever must has add_nodes func."""
if not isinstance(self.retriever, ModifiableRAGRetriever):
raise TypeError(f"must be inplement to add_docs: {type(self.retriever)}")
self._ensure_retriever_modifiable()
documents = SimpleDirectoryReader(input_files=input_files).load_data()
nodes = self.index.service_context.node_parser.get_nodes_from_documents(documents)
self.retriever.add_nodes(nodes)
def add_objs(self, obj_list: list[RAGObject]):
def add_objs(self, objs: list[RAGObject]):
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
nodes = [TextNode(text=obj.rag_key(), metadata={"obj": obj}) for obj in obj_list]
self._ensure_retriever_modifiable()
nodes = [TextNode(text=obj.rag_key(), metadata={"obj": obj}) for obj in objs]
self.retriever.add_nodes(nodes)
def _ensure_retriever_modifiable(self):
if not isinstance(self.retriever, ModifiableRAGRetriever):
raise TypeError(f"the retriever is not modifiable: {type(self.retriever)}")