From 191a86f93e0c448b40db201f5e4f697d29737e8c Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 21 Mar 2024 12:04:06 +0800 Subject: [PATCH] rag add es --- examples/rag_pipeline.py | 74 +++++++++++++++++++---- metagpt/rag/factories/index.py | 47 +++++++++----- metagpt/rag/factories/retriever.py | 17 +++++- metagpt/rag/retrievers/es_retriever.py | 17 ++++++ metagpt/rag/retrievers/faiss_retriever.py | 2 +- metagpt/rag/schema.py | 28 ++++++++- requirements.txt | 3 + 7 files changed, 157 insertions(+), 31 deletions(-) create mode 100644 metagpt/rag/retrievers/es_retriever.py diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index 5a313d7bb..ae6e7b7bc 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -1,6 +1,7 @@ """RAG pipeline""" import asyncio +from functools import wraps from pydantic import BaseModel @@ -11,6 +12,9 @@ from metagpt.rag.schema import ( BM25RetrieverConfig, ChromaIndexConfig, ChromaRetrieverConfig, + ElasticsearchIndexConfig, + ElasticsearchRetrieverConfig, + ElasticsearchStoreConfig, FAISSRetrieverConfig, LLMRankerConfig, ) @@ -24,6 +28,17 @@ TRAVEL_QUESTION = "What does Bob like?" LLM_TIP = "If you not sure, just answer I don't know." +def catch_exception(func): + @wraps(func) + async def wrapper(*args, **kwargs): + try: + return await func(*args, **kwargs) + except Exception as e: + logger.error(f"{func.__name__} exception: {e}") + + return wrapper + + class Player(BaseModel): """To demonstrate rag add objs.""" @@ -39,12 +54,22 @@ class Player(BaseModel): class RAGExample: """Show how to use RAG.""" - def __init__(self): - self.engine = SimpleEngine.from_docs( - input_files=[DOC_PATH], - retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()], - ranker_configs=[LLMRankerConfig()], - ) + def __init__(self, engine: SimpleEngine = None): + self._engine = engine + + @property + def engine(self): + if not self._engine: + self._engine = SimpleEngine.from_docs( + input_files=[DOC_PATH], + retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()], + ranker_configs=[LLMRankerConfig()], + ) + return self._engine + + @engine.setter + def engine(self, value: SimpleEngine): + self._engine = value async def run_pipeline(self, question=QUESTION, print_title=True): """This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like: @@ -97,6 +122,7 @@ class RAGExample: self.engine.add_docs([travel_filepath]) await self.run_pipeline(question=travel_question, print_title=False) + @catch_exception async def add_objects(self, print_title=True): """This example show how to add objects. @@ -154,20 +180,43 @@ class RAGExample: """ self._print_title("Init And Query ChromaDB") - # save index + # 1.save index output_dir = DATA_PATH / "rag" SimpleEngine.from_docs( input_files=[TRAVEL_DOC_PATH], retriever_configs=[ChromaRetrieverConfig(persist_path=output_dir)], ) - # load index - engine = SimpleEngine.from_index( - index_config=ChromaIndexConfig(persist_path=output_dir), + # 2.load index + engine = SimpleEngine.from_index(index_config=ChromaIndexConfig(persist_path=output_dir)) + + # 3.query + answer = await engine.aquery(TRAVEL_QUESTION) + self._print_query_result(answer) + + @catch_exception + async def init_and_query_es(self): + """This example show how to use es. how to save and load index. will print something like: + + Query Result: + Bob likes traveling. + + If `Unclosed client session`, it's llamaindex elasticsearch problem, maybe fixed later. + """ + self._print_title("Init And Query Elasticsearch") + + # 1.create es index and save docs + store_config = ElasticsearchStoreConfig(index_name="travel", es_url="http://127.0.0.1:9200") + engine = SimpleEngine.from_docs( + input_files=[TRAVEL_DOC_PATH], + retriever_configs=[ElasticsearchRetrieverConfig(store_config=store_config)], ) - # query - answer = engine.query(TRAVEL_QUESTION) + # 2.load index + engine = SimpleEngine.from_index(index_config=ElasticsearchIndexConfig(store_config=store_config)) + + # 3.query + answer = await engine.aquery(TRAVEL_QUESTION) self._print_query_result(answer) @staticmethod @@ -205,6 +254,7 @@ async def main(): await e.add_objects() await e.init_objects() await e.init_and_query_chromadb() + await e.init_and_query_es() if __name__ == "__main__": diff --git a/metagpt/rag/factories/index.py b/metagpt/rag/factories/index.py index 6aad695e7..5ab7992a0 100644 --- a/metagpt/rag/factories/index.py +++ b/metagpt/rag/factories/index.py @@ -4,6 +4,8 @@ import chromadb from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex +from llama_index.core.vector_stores.types import BasePydanticVectorStore +from llama_index.vector_stores.elasticsearch import ElasticsearchStore from llama_index.vector_stores.faiss import FaissVectorStore from metagpt.rag.factories.base import ConfigBasedFactory @@ -11,6 +13,7 @@ from metagpt.rag.schema import ( BaseIndexConfig, BM25IndexConfig, ChromaIndexConfig, + ElasticsearchIndexConfig, FAISSIndexConfig, ) from metagpt.rag.vector_stores.chroma import ChromaVectorStore @@ -22,6 +25,7 @@ class RAGIndexFactory(ConfigBasedFactory): FAISSIndexConfig: self._create_faiss, ChromaIndexConfig: self._create_chroma, BM25IndexConfig: self._create_bm25, + ElasticsearchIndexConfig: self._create_es, } super().__init__(creators) @@ -30,31 +34,44 @@ class RAGIndexFactory(ConfigBasedFactory): return super().get_instance(config, **kwargs) def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex: - embed_model = self._extract_embed_model(config, **kwargs) - vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path)) storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=config.persist_path) - index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model) - return index + + return self._index_from_storage(storage_context=storage_context, config=config, **kwargs) + + def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex: + storage_context = StorageContext.from_defaults(persist_dir=config.persist_path) + + return self._index_from_storage(storage_context=storage_context, config=config, **kwargs) def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex: - embed_model = self._extract_embed_model(config, **kwargs) - db = chromadb.PersistentClient(str(config.persist_path)) chroma_collection = db.get_or_create_collection(config.collection_name) vector_store = ChromaVectorStore(chroma_collection=chroma_collection) - index = VectorStoreIndex.from_vector_store( - vector_store, - embed_model=embed_model, - ) - return index - def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex: + return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs) + + def _create_es(self, config: ElasticsearchIndexConfig, **kwargs) -> VectorStoreIndex: + vector_store = ElasticsearchStore(**config.store_config.model_dump()) + + return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs) + + def _index_from_storage( + self, storage_context: StorageContext, config: BaseIndexConfig, **kwargs + ) -> VectorStoreIndex: embed_model = self._extract_embed_model(config, **kwargs) - storage_context = StorageContext.from_defaults(persist_dir=config.persist_path) - index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model) - return index + return load_index_from_storage(storage_context=storage_context, embed_model=embed_model) + + def _index_from_vector_store( + self, vector_store: BasePydanticVectorStore, config: BaseIndexConfig, **kwargs + ) -> VectorStoreIndex: + embed_model = self._extract_embed_model(config, **kwargs) + + return VectorStoreIndex.from_vector_store( + vector_store=vector_store, + embed_model=embed_model, + ) def _extract_embed_model(self, config, **kwargs) -> BaseEmbedding: return self._val_from_config_or_kwargs("embed_model", config, **kwargs) diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index ba48c753e..47ceadf00 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -6,18 +6,21 @@ import chromadb import faiss from llama_index.core import StorageContext, VectorStoreIndex from llama_index.core.vector_stores.types import BasePydanticVectorStore +from llama_index.vector_stores.elasticsearch import ElasticsearchStore from llama_index.vector_stores.faiss import FaissVectorStore from metagpt.rag.factories.base import ConfigBasedFactory from metagpt.rag.retrievers.base import RAGRetriever from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever +from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever from metagpt.rag.schema import ( BaseRetrieverConfig, BM25RetrieverConfig, ChromaRetrieverConfig, + ElasticsearchRetrieverConfig, FAISSRetrieverConfig, IndexRetrieverConfig, ) @@ -32,6 +35,7 @@ class RetrieverFactory(ConfigBasedFactory): FAISSRetrieverConfig: self._create_faiss_retriever, BM25RetrieverConfig: self._create_bm25_retriever, ChromaRetrieverConfig: self._create_chroma_retriever, + ElasticsearchRetrieverConfig: self._create_es_retriever, } super().__init__(creators) @@ -53,20 +57,29 @@ class RetrieverFactory(ConfigBasedFactory): def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever: vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + return FAISSRetriever(**config.model_dump()) def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever: config.index = copy.deepcopy(self._extract_index(config, **kwargs)) - nodes = list(config.index.docstore.docs.values()) - return DynamicBM25Retriever(nodes=nodes, **config.model_dump()) + + return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump()) def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever: db = chromadb.PersistentClient(path=str(config.persist_path)) chroma_collection = db.get_or_create_collection(config.collection_name) + vector_store = ChromaVectorStore(chroma_collection=chroma_collection) config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + return ChromaRetriever(**config.model_dump()) + def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever: + vector_store = ElasticsearchStore(**config.store_config.model_dump()) + config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + + return ElasticsearchRetriever(**config.model_dump()) + def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex: return self._val_from_config_or_kwargs("index", config, **kwargs) diff --git a/metagpt/rag/retrievers/es_retriever.py b/metagpt/rag/retrievers/es_retriever.py new file mode 100644 index 000000000..a1a0a6138 --- /dev/null +++ b/metagpt/rag/retrievers/es_retriever.py @@ -0,0 +1,17 @@ +"""Elasticsearch retriever.""" + +from llama_index.core.retrievers import VectorIndexRetriever +from llama_index.core.schema import BaseNode + + +class ElasticsearchRetriever(VectorIndexRetriever): + """Elasticsearch retriever.""" + + def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None: + """Support add nodes.""" + self._index.insert_nodes(nodes, **kwargs) + + def persist(self, persist_dir: str, **kwargs) -> None: + """Support persist. + + Elasticsearch automatically saves, so there is no need to implement.""" diff --git a/metagpt/rag/retrievers/faiss_retriever.py b/metagpt/rag/retrievers/faiss_retriever.py index 7e543cce2..80b409292 100644 --- a/metagpt/rag/retrievers/faiss_retriever.py +++ b/metagpt/rag/retrievers/faiss_retriever.py @@ -8,7 +8,7 @@ class FAISSRetriever(VectorIndexRetriever): """FAISS retriever.""" def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None: - """Support add nodes""" + """Support add nodes.""" self._index.insert_nodes(nodes, **kwargs) def persist(self, persist_dir: str, **kwargs) -> None: diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index cae1c2979..e98a6fc89 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -6,6 +6,7 @@ from typing import Any, Union from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex from llama_index.core.schema import TextNode +from llama_index.core.vector_stores.types import VectorStoreQueryMode from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from metagpt.rag.interface import RAGObject @@ -46,6 +47,24 @@ class ChromaRetrieverConfig(IndexRetrieverConfig): collection_name: str = Field(default="metagpt", description="The name of the collection.") +class ElasticsearchStoreConfig(BaseModel): + index_name: str = Field(default="metagpt", description="Name of the Elasticsearch index.") + es_url: str = Field(default=None, description="Elasticsearch URL.") + es_cloud_id: str = Field(default=None, description="Elasticsearch cloud ID.") + es_api_key: str = Field(default=None, description="Elasticsearch API key.") + es_user: str = Field(default=None, description="Elasticsearch username.") + es_password: str = Field(default=None, description="Elasticsearch password.") + batch_size: int = Field(default=200, description="Batch size for bulk indexing.") + distance_strategy: str = Field(default="COSINE", description="Distance strategy to use for similarity search.") + + +class ElasticsearchRetrieverConfig(IndexRetrieverConfig): + """Config for Elasticsearch-based retrievers.""" + + store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.") + vector_store_query_mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT + + class BaseRankerConfig(BaseModel): """Common config for rankers. @@ -53,7 +72,6 @@ class BaseRankerConfig(BaseModel): """ model_config = ConfigDict(arbitrary_types_allowed=True) - top_n: int = Field(default=5, description="The number of top results to return.") @@ -72,6 +90,7 @@ class BaseIndexConfig(BaseModel): If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.index. """ + model_config = ConfigDict(arbitrary_types_allowed=True) persist_path: Union[str, Path] = Field(description="The directory of saved data.") @@ -97,6 +116,13 @@ class BM25IndexConfig(BaseIndexConfig): _no_embedding: bool = PrivateAttr(default=True) +class ElasticsearchIndexConfig(VectorIndexConfig): + """Config for es-based index.""" + + store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.") + persist_path: Union[str, Path] = "" + + class ObjectNodeMetadata(BaseModel): """Metadata of ObjectNode.""" diff --git a/requirements.txt b/requirements.txt index 326fa8bb9..3e545d146 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,6 +17,7 @@ llama-index-llms-azure-openai==0.1.4 llama-index-readers-file==0.1.4 llama-index-retrievers-bm25==0.1.3 llama-index-vector-stores-faiss==0.1.1 +llama-index-vector-stores-elasticsearch==0.1.5 chromadb==0.4.23 loguru==0.6.0 meilisearch==0.21.0 @@ -76,3 +77,5 @@ Pillow imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py qianfan==0.3.2 dashscope==1.14.1 +rank-bm25==0.2.2 # for tool recommendation +jieba==0.42.1 # for tool recommendation