mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-06 22:32:38 +02:00
rag pipeline
This commit is contained in:
parent
4fcf724797
commit
916b139e2b
18 changed files with 372 additions and 15 deletions
0
metagpt/rag/__init__.py
Normal file
0
metagpt/rag/__init__.py
Normal file
3
metagpt/rag/engines/__init__.py
Normal file
3
metagpt/rag/engines/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from metagpt.rag.engines.simple import SimpleEngine
|
||||
|
||||
__all__ = ["SimpleEngine"]
|
||||
48
metagpt/rag/engines/simple.py
Normal file
48
metagpt/rag/engines/simple.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
"""Simple Engine."""
|
||||
from typing import Optional
|
||||
|
||||
from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
|
||||
from llama_index.constants import DEFAULT_SIMILARITY_TOP_K
|
||||
from llama_index.embeddings.base import BaseEmbedding
|
||||
from llama_index.llms.llm import LLM
|
||||
from llama_index.query_engine import RetrieverQueryEngine
|
||||
from llama_index.retrievers import VectorIndexRetriever
|
||||
|
||||
from metagpt.rag.llm import get_default_llm
|
||||
from metagpt.utils.embedding import get_embedding
|
||||
|
||||
|
||||
class SimpleEngine(RetrieverQueryEngine):
    """SimpleEngine is a search engine that uses a vector index for retrieving documents."""

    @classmethod
    def from_docs(
        cls,
        input_dir: str = None,
        input_files: list = None,
        embed_model: BaseEmbedding = None,
        llm: LLM = None,
        # node parser kwargs
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        # retrieve kwargs
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
    ) -> "SimpleEngine":
        """Build a simple vector-index engine from a directory or explicit file list.

        Args:
            input_dir: directory to read documents from (mutually usable with input_files).
            input_files: explicit list of file paths to read.
            embed_model: embedding model; defaults to the project-wide get_embedding().
            llm: LLM used for response synthesis; defaults to get_default_llm().
            chunk_size: node parser chunk size, forwarded to ServiceContext.
            chunk_overlap: node parser chunk overlap, forwarded to ServiceContext.
            similarity_top_k: number of top-similarity nodes to retrieve.

        Returns:
            A ready-to-query engine instance.
        """
        documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
        service_context = ServiceContext.from_defaults(
            embed_model=embed_model or get_embedding(),
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            llm=llm or get_default_llm(),
        )
        index = VectorStoreIndex.from_documents(documents, service_context=service_context)
        retriever = VectorIndexRetriever(index=index, similarity_top_k=similarity_top_k)

        # Use cls (not the hard-coded class name) so subclasses of SimpleEngine
        # constructed via from_docs get instances of themselves.
        return cls(retriever=retriever)

    async def asearch(self, content: str, **kwargs) -> str:
        """Implement tools.SearchInterface.

        NOTE(review): aquery returns a llama_index Response object, not a plain
        str as annotated — confirm callers rely only on its str() form.
        """
        return await self.aquery(content)
|
||||
7
metagpt/rag/llm.py
Normal file
7
metagpt/rag/llm.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
from llama_index.llms import OpenAI
|
||||
|
||||
from metagpt.config2 import config
|
||||
|
||||
|
||||
def get_default_llm() -> OpenAI:
    """Return an OpenAI LLM client configured from the global MetaGPT config."""
    llm_conf = config.llm
    return OpenAI(api_base=llm_conf.base_url, api_key=llm_conf.api_key)
|
||||
0
metagpt/rag/rankers/__init__.py
Normal file
0
metagpt/rag/rankers/__init__.py
Normal file
20
metagpt/rag/rankers/base.py
Normal file
20
metagpt/rag/rankers/base.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
"""Base Ranker."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from llama_index import QueryBundle
|
||||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.schema import NodeWithScore
|
||||
|
||||
|
||||
class RAGRanker(BaseNodePostprocessor):
    """Abstract base class for RAG node re-rankers.

    Inherits from llama_index's BaseNodePostprocessor; concrete rankers
    implement _postprocess_nodes to reorder/filter retrieved nodes.
    """

    @abstractmethod
    def _postprocess_nodes(
        self,
        nodes: list[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> list[NodeWithScore]:
        """Postprocess (e.g. re-rank) the given nodes for the optional query.

        Args:
            nodes: retrieved nodes with similarity scores.
            query_bundle: the originating query, if available.

        Returns:
            The processed list of nodes.
        """
|
||||
4
metagpt/rag/retrievers/__init__.py
Normal file
4
metagpt/rag/retrievers/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
"""init"""
|
||||
from metagpt.rag.retrievers.hybrid import SimpleHybridRetriever
|
||||
|
||||
__all__ = ["SimpleHybridRetriever"]
|
||||
18
metagpt/rag/retrievers/base.py
Normal file
18
metagpt/rag/retrievers/base.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
"""Base retriever."""
|
||||
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
from llama_index.retrievers import BaseRetriever
|
||||
from llama_index.schema import NodeWithScore, QueryType
|
||||
|
||||
|
||||
class RAGRetriever(BaseRetriever):
    """Abstract base class for RAG retrievers; inherits from llama_index's BaseRetriever."""

    @abstractmethod
    async def _aretrieve(self, query: QueryType) -> list[NodeWithScore]:
        """Asynchronously retrieve nodes relevant to the query."""

    def _retrieve(self, query: QueryType) -> list[NodeWithScore]:
        """Synchronous retrieval is not provided by this base class.

        The original stub silently returned None despite the list annotation,
        which would crash any sync caller later; fail loudly here instead.
        """
        raise NotImplementedError("Use the asynchronous retrieval interface (_aretrieve).")
|
||||
36
metagpt/rag/retrievers/hybrid.py
Normal file
36
metagpt/rag/retrievers/hybrid.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
"""Hybrid retriever."""
|
||||
from llama_index.schema import QueryType
|
||||
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
|
||||
|
||||
class SimpleHybridRetriever(RAGRetriever):
    """Composite retriever that aggregates search results from multiple retrievers."""

    def __init__(self, *retrievers):
        """Args:
        *retrievers: the sub-retrievers whose results will be merged.
        """
        # Store a real list so the runtime type matches the annotation
        # (the *args tuple previously stored was not a list).
        self.retrievers: list[RAGRetriever] = list(retrievers)
        super().__init__()

    async def _aretrieve(self, query: QueryType, **kwargs):
        """Asynchronously retrieve and aggregate search results from all retrievers.

        Queries each configured retriever with the given query (and extra
        keyword arguments), then combines the results, keeping each node at
        most once, keyed by the node's ID. First occurrence wins, so result
        order follows retriever order.
        """
        all_nodes = []
        for retriever in self.retrievers:
            # NOTE(review): extra kwargs are forwarded to aretrieve — confirm
            # the base aretrieve signature accepts them.
            nodes = await retriever.aretrieve(query, **kwargs)
            all_nodes.extend(nodes)

        # De-duplicate by node id while preserving first-seen order.
        result = []
        node_ids = set()
        for n in all_nodes:
            if n.node.node_id not in node_ids:
                result.append(n)
                node_ids.add(n.node.node_id)
        return result
|
||||
Loading…
Add table
Add a link
Reference in a new issue