mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-27 09:46:24 +02:00
rag add objs
This commit is contained in:
parent
cd605bf8f4
commit
a35f13b4c4
3 changed files with 120 additions and 59 deletions
|
|
@ -12,9 +12,10 @@ from llama_index.llms.llm import LLM
|
|||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.query_engine import RetrieverQueryEngine
|
||||
from llama_index.response_synthesizers import BaseSynthesizer
|
||||
from llama_index.schema import NodeWithScore, QueryBundle, QueryType
|
||||
from llama_index.schema import NodeWithScore, QueryBundle, QueryType, TextNode
|
||||
|
||||
from metagpt.rag.factory import get_rankers, get_retriever
|
||||
from metagpt.rag.interface import RAGObject
|
||||
from metagpt.rag.llm import get_default_llm
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
|
||||
from metagpt.rag.schema import RankerConfigType, RetrieverConfigType
|
||||
|
|
@ -92,10 +93,15 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
return await super().aretrieve(query_bundle)
|
||||
|
||||
def add_docs(self, input_files: list[str]):
|
||||
"""Add docs to retriever. retriever must has add_nodes func"""
|
||||
"""Add docs to retriever. retriever must has add_nodes func."""
|
||||
if not isinstance(self.retriever, ModifiableRAGRetriever):
|
||||
raise TypeError(f"must be inplement to add_docs: {type(self.retriever)}")
|
||||
|
||||
documents = SimpleDirectoryReader(input_files=input_files).load_data()
|
||||
nodes = self.index.service_context.node_parser.get_nodes_from_documents(documents)
|
||||
self.retriever.add_nodes(nodes)
|
||||
|
||||
def add_objs(self, obj_list: list[RAGObject]):
|
||||
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
|
||||
nodes = [TextNode(text=obj.rag_key(), metadata={"obj": obj}) for obj in obj_list]
|
||||
self.retriever.add_nodes(nodes)
|
||||
|
|
|
|||
6
metagpt/rag/interface.py
Normal file
6
metagpt/rag/interface.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from typing import Protocol
|
||||
|
||||
|
||||
class RAGObject(Protocol):
|
||||
def rag_key(self) -> str:
|
||||
"""for rag search"""
|
||||
Loading…
Add table
Add a link
Reference in a new issue