rag add objs

This commit is contained in:
seehi 2024-02-07 18:19:22 +08:00 committed by betterwang
parent cd605bf8f4
commit a35f13b4c4
3 changed files with 120 additions and 59 deletions

View file

@ -12,9 +12,10 @@ from llama_index.llms.llm import LLM
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.response_synthesizers import BaseSynthesizer
from llama_index.schema import NodeWithScore, QueryBundle, QueryType
from llama_index.schema import NodeWithScore, QueryBundle, QueryType, TextNode
from metagpt.rag.factory import get_rankers, get_retriever
from metagpt.rag.interface import RAGObject
from metagpt.rag.llm import get_default_llm
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
from metagpt.rag.schema import RankerConfigType, RetrieverConfigType
@ -92,10 +93,15 @@ class SimpleEngine(RetrieverQueryEngine):
return await super().aretrieve(query_bundle)
def add_docs(self, input_files: list[str]):
"""Add docs to retriever. retriever must has add_nodes func"""
"""Add docs to retriever. retriever must has add_nodes func."""
if not isinstance(self.retriever, ModifiableRAGRetriever):
raise TypeError(f"must be inplement to add_docs: {type(self.retriever)}")
documents = SimpleDirectoryReader(input_files=input_files).load_data()
nodes = self.index.service_context.node_parser.get_nodes_from_documents(documents)
self.retriever.add_nodes(nodes)
def add_objs(self, obj_list: list[RAGObject]):
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
nodes = [TextNode(text=obj.rag_key(), metadata={"obj": obj}) for obj in obj_list]
self.retriever.add_nodes(nodes)

6
metagpt/rag/interface.py Normal file
View file

@ -0,0 +1,6 @@
from typing import Protocol
class RAGObject(Protocol):
def rag_key(self) -> str:
"""for rag search"""