RAG pipeline

This commit is contained in:
seehi 2024-01-30 20:19:50 +08:00
parent 4fcf724797
commit 916b139e2b
18 changed files with 372 additions and 15 deletions

0
metagpt/rag/__init__.py Normal file
View file

View file

@@ -0,0 +1,3 @@
from metagpt.rag.engines.simple import SimpleEngine
__all__ = ["SimpleEngine"]

View file

@@ -0,0 +1,48 @@
"""Simple Engine."""
from typing import Optional
from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
from llama_index.constants import DEFAULT_SIMILARITY_TOP_K
from llama_index.embeddings.base import BaseEmbedding
from llama_index.llms.llm import LLM
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.retrievers import VectorIndexRetriever
from metagpt.rag.llm import get_default_llm
from metagpt.utils.embedding import get_embedding
class SimpleEngine(RetrieverQueryEngine):
    """SimpleEngine is a search engine that uses a vector index for retrieving documents."""

    @classmethod
    def from_docs(
        cls,
        input_dir: Optional[str] = None,
        input_files: Optional[list] = None,
        embed_model: Optional[BaseEmbedding] = None,
        llm: Optional[LLM] = None,
        # node parser kwargs
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        # retrieve kwargs
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
    ) -> "SimpleEngine":
        """Build an engine from documents on disk. Simple and straightforward.

        Args:
            input_dir: directory to load documents from (passed to SimpleDirectoryReader).
            input_files: explicit list of file paths to load instead of / alongside input_dir.
            embed_model: embedding model; defaults to get_embedding().
            llm: LLM used for response synthesis; defaults to get_default_llm().
            chunk_size: node parser chunk size.
            chunk_overlap: node parser chunk overlap.
            similarity_top_k: number of most-similar nodes to retrieve.

        Returns:
            A ready-to-query engine instance.
        """
        documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
        service_context = ServiceContext.from_defaults(
            embed_model=embed_model or get_embedding(),
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            llm=llm or get_default_llm(),
        )
        index = VectorStoreIndex.from_documents(documents, service_context=service_context)
        retriever = VectorIndexRetriever(index=index, similarity_top_k=similarity_top_k)
        # Use cls, not the hard-coded class name, so subclasses get instances
        # of themselves from this factory.
        return cls(retriever=retriever)

    async def asearch(self, content: str, **kwargs) -> str:
        """Implement tools.SearchInterface: answer `content` via aquery."""
        return await self.aquery(content)

7
metagpt/rag/llm.py Normal file
View file

@@ -0,0 +1,7 @@
from llama_index.llms import OpenAI
from metagpt.config2 import config
def get_default_llm() -> OpenAI:
    """Create an OpenAI client wired to the project's configured LLM endpoint."""
    llm_settings = config.llm
    return OpenAI(api_base=llm_settings.base_url, api_key=llm_settings.api_key)

View file

View file

@@ -0,0 +1,20 @@
"""Base Ranker."""
from abc import abstractmethod
from typing import Optional
from llama_index import QueryBundle
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.schema import NodeWithScore
class RAGRanker(BaseNodePostprocessor):
    """Abstract re-ranker interface; inherits from llama_index's BaseNodePostprocessor.

    Subclasses implement _postprocess_nodes to reorder or filter retrieved nodes.
    """
    @abstractmethod
    def _postprocess_nodes(
        self,
        nodes: list[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> list[NodeWithScore]:
        """Postprocess (rerank/filter) `nodes` for `query_bundle`; return the result list."""

View file

@@ -0,0 +1,4 @@
"""init"""
from metagpt.rag.retrievers.hybrid import SimpleHybridRetriever
__all__ = ["SimpleHybridRetriever"]

View file

@@ -0,0 +1,18 @@
"""Base retriever."""
from abc import abstractmethod
from llama_index.retrievers import BaseRetriever
from llama_index.schema import NodeWithScore, QueryType
class RAGRetriever(BaseRetriever):
    """Abstract retriever interface; inherits from llama_index's BaseRetriever.

    Subclasses implement the async _aretrieve; the sync path is a stub here.
    """
    @abstractmethod
    async def _aretrieve(self, query: QueryType) -> list[NodeWithScore]:
        """Asynchronously retrieve nodes matching `query`."""
    def _retrieve(self, query: QueryType) -> list[NodeWithScore]:
        """retrieve nodes"""
        # NOTE(review): docstring-only body — implicitly returns None, not the
        # annotated list. Presumably sync retrieval is intentionally unsupported;
        # confirm, or raise NotImplementedError explicitly.

View file

@@ -0,0 +1,36 @@
"""Hybrid retriever."""
from llama_index.schema import QueryType
from metagpt.rag.retrievers.base import RAGRetriever
class SimpleHybridRetriever(RAGRetriever):
    """Composite retriever that aggregates search results from multiple retrievers."""

    def __init__(self, *retrievers):
        # Materialize the *args tuple as a list so the attribute actually
        # matches its list[RAGRetriever] annotation.
        self.retrievers: list[RAGRetriever] = list(retrievers)
        super().__init__()

    async def _aretrieve(self, query: QueryType, **kwargs):
        """Asynchronously retrieve and aggregate results from all sub-retrievers.

        Queries each retriever in `self.retrievers` with `query` (plus any extra
        keyword arguments) and combines the results, deduplicating by node ID.
        The first occurrence of each node wins, preserving retriever order.
        """
        all_nodes = []
        for retriever in self.retrievers:
            nodes = await retriever.aretrieve(query, **kwargs)
            all_nodes.extend(nodes)

        # Deduplicate by node_id; dicts preserve insertion order, so the first
        # retriever to return a node keeps its scored copy.
        unique_by_id = {}
        for scored_node in all_nodes:
            unique_by_id.setdefault(scored_node.node.node_id, scored_node)
        return list(unique_by_id.values())