rag pipeline

This commit is contained in:
seehi 2024-02-02 23:04:38 +08:00
parent 0b0be04cf1
commit 3ae422193d
8 changed files with 227 additions and 148 deletions

View file

@ -1,14 +1,15 @@
"""Simple Engine."""
from typing import Optional
from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
from llama_index.constants import DEFAULT_SIMILARITY_TOP_K
from llama_index import ServiceContext, SimpleDirectoryReader
from llama_index.embeddings.base import BaseEmbedding
from llama_index.llms.llm import LLM
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.retrievers import VectorIndexRetriever
from llama_index.schema import NodeWithScore, QueryBundle, QueryType
from metagpt.rag.llm import get_default_llm
from metagpt.rag.rankers import get_rankers
from metagpt.rag.retrievers import get_retriever
from metagpt.rag.schema import RankerConfig, RetrieverConfig
from metagpt.utils.embedding import get_embedding
@ -22,27 +23,38 @@ class SimpleEngine(RetrieverQueryEngine):
cls,
input_dir: str = None,
input_files: list = None,
embed_model: BaseEmbedding = None,
llm: LLM = None,
# node parser kwargs
chunk_size: Optional[int] = None,
chunk_overlap: Optional[int] = None,
# retrieve kwargs
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
embed_model: BaseEmbedding = None,
chunk_size: int = None,
chunk_overlap: int = None,
retriever_configs: list[RetrieverConfig] = None,
ranker_configs: list[RankerConfig] = None,
) -> "SimpleEngine":
"""This engine is designed to be simple and straightforward"""
"""This engine is designed to be simple and straightforward
Args:
input_dir (str): Path to the directory.
input_files (list): List of file paths to read
(Optional; overrides input_dir, exclude)
"""
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
service_context = ServiceContext.from_defaults(
llm=llm or get_default_llm(),
embed_model=embed_model or get_embedding(),
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
llm=llm or get_default_llm(),
)
index = VectorStoreIndex.from_documents(documents, service_context=service_context)
retriever = VectorIndexRetriever(index=index, similarity_top_k=similarity_top_k)
nodes = service_context.node_parser.get_nodes_from_documents(documents)
retriever = get_retriever(nodes, configs=retriever_configs, service_context=service_context)
rankers = get_rankers(configs=ranker_configs, service_context=service_context)
return SimpleEngine(retriever=retriever)
return SimpleEngine(retriever=retriever, node_postprocessors=rankers)
async def asearch(self, content: str, **kwargs) -> str:
"""Inplement tools.SearchInterface"""
return await self.aquery(content)
async def aretrieve(self, query: QueryType) -> list[NodeWithScore]:
"""Allow query to be str"""
query_bundle = QueryBundle(query) if isinstance(query, str) else query
return await super().aretrieve(query_bundle)

View file

@ -0,0 +1,34 @@
"""init"""
from metagpt.rag.schema import RankerConfig, LLMRankerConfig
from llama_index import ServiceContext
from llama_index.postprocessor import LLMRerank
from llama_index.postprocessor.types import BaseNodePostprocessor
def get_rankers(
configs: list[RankerConfig] = None, service_context: ServiceContext = None
) -> list[BaseNodePostprocessor]:
if not configs:
return [_default_ranker(service_context)]
return [_get_ranker(config, service_context) for config in configs]
def _default_ranker(service_context: ServiceContext = None):
return LLMRerank(top_n=LLMRankerConfig().top_n, service_context=service_context)
def _get_ranker(config: RankerConfig, service_context: ServiceContext = None):
ranker_factory = {
LLMRankerConfig: _create_llm_ranker,
}
create_func = ranker_factory.get(type(config))
if create_func:
return create_func(config, service_context)
raise ValueError(f"Unknown ranker config: {config}")
def _create_llm_ranker(config, service_context=None):
return LLMRerank(top_n=config.top_n, service_context=service_context)

View file

@ -1,4 +1,55 @@
"""init"""
from metagpt.rag.retrievers.hybrid import SimpleHybridRetriever
__all__ = ["SimpleHybridRetriever", "get_retriever"]
__all__ = ["SimpleHybridRetriever"]
from llama_index import (
ServiceContext,
StorageContext,
VectorStoreIndex,
)
from llama_index.retrievers import BaseRetriever, BM25Retriever, VectorIndexRetriever
from llama_index.schema import BaseNode
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.retrievers.hybrid import SimpleHybridRetriever
from metagpt.rag.schema import RetrieverConfig, FAISSRetrieverConfig, BM25RetrieverConfig
import faiss
def get_retriever(
nodes: list[BaseNode], configs: list[RetrieverConfig] = None, service_context: ServiceContext = None
) -> BaseRetriever:
if not configs:
return _default_retriever(nodes, service_context)
retrivers = [_get_retriever(nodes, config, service_context) for config in configs]
return SimpleHybridRetriever(*retrivers, service_context=service_context) if len(retrivers) > 1 else retrivers[0]
def _default_retriever(nodes: list[BaseNode], service_context: ServiceContext = None) -> BaseRetriever:
return VectorStoreIndex(nodes=nodes, service_context=service_context).as_retriever()
def _get_retriever(
nodes: list[BaseNode], config: RetrieverConfig, service_context: ServiceContext = None
) -> BaseRetriever:
retriever_factory = {
FAISSRetrieverConfig: _create_faiss_retriever,
BM25RetrieverConfig: _create_bm25_retriever,
}
create_func = retriever_factory.get(type(config))
if create_func:
return create_func(nodes, config, service_context)
raise ValueError(f"Unknown retriever config: {config}")
def _create_faiss_retriever(nodes: list[BaseNode], config: RetrieverConfig, service_context: ServiceContext):
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
storage_context = StorageContext.from_defaults(vector_store=vector_store)
vector_index = VectorStoreIndex(nodes=nodes, storage_context=storage_context, service_context=service_context)
return VectorIndexRetriever(index=vector_index, similarity_top_k=config.similarity_top_k)
def _create_bm25_retriever(nodes: list[BaseNode], config: RetrieverConfig, service_context: ServiceContext = None):
return BM25Retriever.from_defaults(**config.model_dump(), nodes=nodes)

View file

@ -1,4 +1,5 @@
"""Hybrid retriever."""
from llama_index import ServiceContext
from llama_index.schema import QueryType
from metagpt.rag.retrievers.base import RAGRetriever
@ -9,8 +10,9 @@ class SimpleHybridRetriever(RAGRetriever):
SimpleHybridRetriever is a composite retriever that aggregates search results from multiple retrievers.
"""
def __init__(self, *retrievers):
def __init__(self, *retrievers, service_context: ServiceContext = None):
self.retrievers: list[RAGRetriever] = retrievers
self.service_context = service_context
super().__init__()
async def _aretrieve(self, query: QueryType, **kwargs):

23
metagpt/rag/schema.py Normal file
View file

@ -0,0 +1,23 @@
"""Retriever schemas"""
from pydantic import BaseModel
class RetrieverConfig(BaseModel):
similarity_top_k: int = 5
class FAISSRetrieverConfig(RetrieverConfig):
dimensions: int = 1536
class BM25RetrieverConfig(RetrieverConfig):
...
class RankerConfig(BaseModel):
top_n: int = 5
class LLMRankerConfig(RankerConfig):
...