add rag pipeline unittest

This commit is contained in:
seehi 2024-02-06 20:15:03 +08:00
parent 254088b026
commit a4c095300c
12 changed files with 355 additions and 85 deletions

View file

@ -1,10 +1,17 @@
"""Simple Engine."""
from typing import Optional
from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
from llama_index.callbacks.base import CallbackManager
from llama_index.core.base_retriever import BaseRetriever
from llama_index.embeddings.base import BaseEmbedding
from llama_index.indices.base import BaseIndex
from llama_index.llms.llm import LLM
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.response_synthesizers import BaseSynthesizer
from llama_index.schema import NodeWithScore, QueryBundle, QueryType
from metagpt.rag.llm import get_default_llm
@ -16,6 +23,29 @@ from metagpt.utils.embedding import get_embedding
class SimpleEngine(RetrieverQueryEngine):
"""
SimpleEngine is a lightweight and easy-to-use search engine that integrates
document reading, embedding, indexing, retrieving, and ranking functionalities
into a single, straightforward workflow. It is designed to quickly set up a
search engine from a collection of documents.
"""
def __init__(
self,
retriever: BaseRetriever,
response_synthesizer: Optional[BaseSynthesizer] = None,
node_postprocessors: Optional[list[BaseNodePostprocessor]] = None,
callback_manager: Optional[CallbackManager] = None,
index: Optional[BaseIndex] = None,
) -> None:
super().__init__(
retriever=retriever,
response_synthesizer=response_synthesizer,
node_postprocessors=node_postprocessors,
callback_manager=callback_manager,
)
self.index = index
@classmethod
def from_docs(
cls,
@ -31,9 +61,14 @@ class SimpleEngine(RetrieverQueryEngine):
"""This engine is designed to be simple and straightforward
Args:
input_dir (str): Path to the directory.
input_files (list): List of file paths to read
(Optional; overrides input_dir, exclude)
input_dir: Path to the directory.
input_files: List of file paths to read (Optional; overrides input_dir, exclude).
llm: Must be supported by llama index.
embed_model: Must be supported by llama index.
chunk_size: The size of text chunks (in tokens) to split documents into for embedding.
chunk_overlap: The number of tokens for overlapping between consecutive chunks. Helps in maintaining context continuity.
retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
ranker_configs: Configuration for rankers.
"""
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
service_context = ServiceContext.from_defaults(
@ -46,7 +81,7 @@ class SimpleEngine(RetrieverQueryEngine):
retriever = get_retriever(index, configs=retriever_configs)
rankers = get_rankers(configs=ranker_configs, service_context=service_context)
return SimpleEngine(retriever=retriever, node_postprocessors=rankers)
return cls(retriever=retriever, node_postprocessors=rankers, index=index)
async def asearch(self, content: str, **kwargs) -> str:
"""Inplement tools.SearchInterface"""
@ -58,6 +93,8 @@ class SimpleEngine(RetrieverQueryEngine):
return await super().aretrieve(query_bundle)
def add_docs(self, input_files: list[str]):
"""Add docs to retriever"""
documents = SimpleDirectoryReader(input_files=input_files).load_data()
nodes = self.index.service_context.node_parser.get_nodes_from_documents(documents)
retriever: RAGRetriever = self.retriever
retriever.add_docs(documents)
retriever.add_nodes(nodes)

View file

@ -1,3 +1,4 @@
"""Rankers Factory"""
from llama_index import ServiceContext
from llama_index.postprocessor import LLMRerank
from llama_index.postprocessor.types import BaseNodePostprocessor
@ -19,18 +20,18 @@ class RankerFactory:
return [self._get_ranker(config, service_context) for config in configs]
def _default_ranker(self, service_context: ServiceContext = None):
def _default_ranker(self, service_context: ServiceContext = None) -> LLMRerank:
return LLMRerank(top_n=LLMRankerConfig().top_n, service_context=service_context)
def _get_ranker(self, config: RankerConfigType, service_context: ServiceContext = None):
def _get_ranker(self, config: RankerConfigType, service_context: ServiceContext = None) -> BaseNodePostprocessor:
create_func = self.ranker_creators.get(type(config))
if create_func:
return create_func(config, service_context)
raise ValueError(f"Unknown ranker config: {config}")
def _create_llm_ranker(self, config, service_context=None):
return LLMRerank(top_n=config.top_n, service_context=service_context)
def _create_llm_ranker(self, config: LLMRankerConfig, service_context=None) -> LLMRerank:
return LLMRerank(**config.model_dump(), service_context=service_context)
get_rankers = RankerFactory().get_rankers

View file

@ -3,21 +3,19 @@
from abc import abstractmethod
from llama_index import Document
from llama_index.retrievers import BaseRetriever
from llama_index.schema import NodeWithScore, QueryType
from llama_index.schema import BaseNode, NodeWithScore, QueryType
class RAGRetriever(BaseRetriever):
"""inherit from llama_index"""
"""Inherit from llama_index"""
@abstractmethod
async def _aretrieve(self, query: QueryType) -> list[NodeWithScore]:
"""retrieve nodes"""
@abstractmethod
def add_docs(self, documents: list[Document]) -> None:
"""add docs"""
"""Retrieve nodes"""
def _retrieve(self, query: QueryType) -> list[NodeWithScore]:
"""retrieve nodes"""
"""Retrieve nodes"""
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
"""To support add docs, must inplement this func"""

View file

@ -1,14 +1,14 @@
from llama_index import Document
from llama_index.retrievers import BM25Retriever
from llama_index.schema import BaseNode
class DynamicBM25Retriever(BM25Retriever):
def add_docs(self, documents: list[Document]):
def add_nodes(self, nodes: list[BaseNode], **kwargs):
try:
from rank_bm25 import BM25Okapi
except ImportError:
raise ImportError("Please install rank_bm25: pip install rank-bm25")
self._nodes.extend(documents)
self._nodes.extend(nodes)
self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
self.bm25 = BM25Okapi(self._corpus)

View file

@ -1,3 +1,4 @@
"""Retriever Factory"""
import faiss
from llama_index import StorageContext, VectorStoreIndex
from llama_index.indices.base import BaseIndex
@ -22,6 +23,7 @@ class RetrieverFactory:
}
def get_retriever(self, index: BaseIndex, configs: list[RetrieverConfigType] = None) -> RAGRetriever:
"""Creates and returns a retriever instance based on the provided configurations."""
if not configs:
return self._default_retriever(index)

View file

@ -1,8 +1,7 @@
from llama_index import Document
from llama_index.retrievers import VectorIndexRetriever
from llama_index.schema import BaseNode
class FAISSRetriever(VectorIndexRetriever):
def add_docs(self, documents: list[Document]):
for document in documents:
self._index.insert(document)
def add_nodes(self, nodes: list[BaseNode], **kwargs):
self._index.insert_nodes(nodes, **kwargs)

View file

@ -1,6 +1,6 @@
"""Hybrid retriever."""
from llama_index import Document, ServiceContext
from llama_index.schema import QueryType
from llama_index import ServiceContext
from llama_index.schema import BaseNode, QueryType
from metagpt.rag.retrievers.base import RAGRetriever
@ -37,6 +37,6 @@ class SimpleHybridRetriever(RAGRetriever):
node_ids.add(n.node.node_id)
return result
def add_docs(self, documents: list[Document]):
def add_nodes(self, nodes: list[BaseNode]):
for r in self.retrievers:
r.add_docs(documents)
r.add_nodes(nodes)