From 191a86f93e0c448b40db201f5e4f697d29737e8c Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 21 Mar 2024 12:04:06 +0800 Subject: [PATCH 01/12] rag add es --- examples/rag_pipeline.py | 74 +++++++++++++++++++---- metagpt/rag/factories/index.py | 47 +++++++++----- metagpt/rag/factories/retriever.py | 17 +++++- metagpt/rag/retrievers/es_retriever.py | 17 ++++++ metagpt/rag/retrievers/faiss_retriever.py | 2 +- metagpt/rag/schema.py | 28 ++++++++- requirements.txt | 3 + 7 files changed, 157 insertions(+), 31 deletions(-) create mode 100644 metagpt/rag/retrievers/es_retriever.py diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index 5a313d7bb..ae6e7b7bc 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -1,6 +1,7 @@ """RAG pipeline""" import asyncio +from functools import wraps from pydantic import BaseModel @@ -11,6 +12,9 @@ from metagpt.rag.schema import ( BM25RetrieverConfig, ChromaIndexConfig, ChromaRetrieverConfig, + ElasticsearchIndexConfig, + ElasticsearchRetrieverConfig, + ElasticsearchStoreConfig, FAISSRetrieverConfig, LLMRankerConfig, ) @@ -24,6 +28,17 @@ TRAVEL_QUESTION = "What does Bob like?" LLM_TIP = "If you not sure, just answer I don't know." +def catch_exception(func): + @wraps(func) + async def wrapper(*args, **kwargs): + try: + return await func(*args, **kwargs) + except Exception as e: + logger.error(f"{func.__name__} exception: {e}") + + return wrapper + + class Player(BaseModel): """To demonstrate rag add objs.""" @@ -39,12 +54,22 @@ class Player(BaseModel): class RAGExample: """Show how to use RAG.""" - def __init__(self): - self.engine = SimpleEngine.from_docs( - input_files=[DOC_PATH], - retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()], - ranker_configs=[LLMRankerConfig()], - ) + def __init__(self, engine: SimpleEngine = None): + self._engine = engine + + @property + def engine(self): + if not self._engine: + self._engine = SimpleEngine.from_docs( + input_files=[DOC_PATH], + retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()], + ranker_configs=[LLMRankerConfig()], + ) + return self._engine + + @engine.setter + def engine(self, value: SimpleEngine): + self._engine = value async def run_pipeline(self, question=QUESTION, print_title=True): """This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like: @@ -97,6 +122,7 @@ class RAGExample: self.engine.add_docs([travel_filepath]) await self.run_pipeline(question=travel_question, print_title=False) + @catch_exception async def add_objects(self, print_title=True): """This example show how to add objects. @@ -154,20 +180,43 @@ class RAGExample: """ self._print_title("Init And Query ChromaDB") - # save index + # 1.save index output_dir = DATA_PATH / "rag" SimpleEngine.from_docs( input_files=[TRAVEL_DOC_PATH], retriever_configs=[ChromaRetrieverConfig(persist_path=output_dir)], ) - # load index - engine = SimpleEngine.from_index( - index_config=ChromaIndexConfig(persist_path=output_dir), + # 2.load index + engine = SimpleEngine.from_index(index_config=ChromaIndexConfig(persist_path=output_dir)) + + # 3.query + answer = await engine.aquery(TRAVEL_QUESTION) + self._print_query_result(answer) + + @catch_exception + async def init_and_query_es(self): + """This example show how to use es. how to save and load index. will print something like: + + Query Result: + Bob likes traveling. + + If `Unclosed client session`, it's llamaindex elasticsearch problem, maybe fixed later. + """ + self._print_title("Init And Query Elasticsearch") + + # 1.create es index and save docs + store_config = ElasticsearchStoreConfig(index_name="travel", es_url="http://127.0.0.1:9200") + engine = SimpleEngine.from_docs( + input_files=[TRAVEL_DOC_PATH], + retriever_configs=[ElasticsearchRetrieverConfig(store_config=store_config)], ) - # query - answer = engine.query(TRAVEL_QUESTION) + # 2.load index + engine = SimpleEngine.from_index(index_config=ElasticsearchIndexConfig(store_config=store_config)) + + # 3.query + answer = await engine.aquery(TRAVEL_QUESTION) self._print_query_result(answer) @staticmethod @@ -205,6 +254,7 @@ async def main(): await e.add_objects() await e.init_objects() await e.init_and_query_chromadb() + await e.init_and_query_es() if __name__ == "__main__": diff --git a/metagpt/rag/factories/index.py b/metagpt/rag/factories/index.py index 6aad695e7..5ab7992a0 100644 --- a/metagpt/rag/factories/index.py +++ b/metagpt/rag/factories/index.py @@ -4,6 +4,8 @@ import chromadb from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex +from llama_index.core.vector_stores.types import BasePydanticVectorStore +from llama_index.vector_stores.elasticsearch import ElasticsearchStore from llama_index.vector_stores.faiss import FaissVectorStore from metagpt.rag.factories.base import ConfigBasedFactory @@ -11,6 +13,7 @@ from metagpt.rag.schema import ( BaseIndexConfig, BM25IndexConfig, ChromaIndexConfig, + ElasticsearchIndexConfig, FAISSIndexConfig, ) from metagpt.rag.vector_stores.chroma import ChromaVectorStore @@ -22,6 +25,7 @@ class RAGIndexFactory(ConfigBasedFactory): FAISSIndexConfig: self._create_faiss, ChromaIndexConfig: self._create_chroma, BM25IndexConfig: self._create_bm25, + ElasticsearchIndexConfig: self._create_es, } super().__init__(creators) @@ -30,31 +34,44 @@ class RAGIndexFactory(ConfigBasedFactory): return super().get_instance(config, **kwargs) def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex: - embed_model = self._extract_embed_model(config, **kwargs) - vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path)) storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=config.persist_path) - index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model) - return index + + return self._index_from_storage(storage_context=storage_context, config=config, **kwargs) + + def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex: + storage_context = StorageContext.from_defaults(persist_dir=config.persist_path) + + return self._index_from_storage(storage_context=storage_context, config=config, **kwargs) def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex: - embed_model = self._extract_embed_model(config, **kwargs) - db = chromadb.PersistentClient(str(config.persist_path)) chroma_collection = db.get_or_create_collection(config.collection_name) vector_store = ChromaVectorStore(chroma_collection=chroma_collection) - index = VectorStoreIndex.from_vector_store( - vector_store, - embed_model=embed_model, - ) - return index - def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex: + return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs) + + def _create_es(self, config: ElasticsearchIndexConfig, **kwargs) -> VectorStoreIndex: + vector_store = ElasticsearchStore(**config.store_config.model_dump()) + + return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs) + + def _index_from_storage( + self, storage_context: StorageContext, config: BaseIndexConfig, **kwargs + ) -> VectorStoreIndex: embed_model = self._extract_embed_model(config, **kwargs) - storage_context = StorageContext.from_defaults(persist_dir=config.persist_path) - index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model) - return index + return load_index_from_storage(storage_context=storage_context, embed_model=embed_model) + + def _index_from_vector_store( + self, vector_store: BasePydanticVectorStore, config: BaseIndexConfig, **kwargs + ) -> VectorStoreIndex: + embed_model = self._extract_embed_model(config, **kwargs) + + return VectorStoreIndex.from_vector_store( + vector_store=vector_store, + embed_model=embed_model, + ) def _extract_embed_model(self, config, **kwargs) -> BaseEmbedding: return self._val_from_config_or_kwargs("embed_model", config, **kwargs) diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index ba48c753e..47ceadf00 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -6,18 +6,21 @@ import chromadb import faiss from llama_index.core import StorageContext, VectorStoreIndex from llama_index.core.vector_stores.types import BasePydanticVectorStore +from llama_index.vector_stores.elasticsearch import ElasticsearchStore from llama_index.vector_stores.faiss import FaissVectorStore from metagpt.rag.factories.base import ConfigBasedFactory from metagpt.rag.retrievers.base import RAGRetriever from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever +from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever from metagpt.rag.schema import ( BaseRetrieverConfig, BM25RetrieverConfig, ChromaRetrieverConfig, + ElasticsearchRetrieverConfig, FAISSRetrieverConfig, IndexRetrieverConfig, ) @@ -32,6 +35,7 @@ class RetrieverFactory(ConfigBasedFactory): FAISSRetrieverConfig: self._create_faiss_retriever, BM25RetrieverConfig: self._create_bm25_retriever, ChromaRetrieverConfig: self._create_chroma_retriever, + ElasticsearchRetrieverConfig: self._create_es_retriever, } super().__init__(creators) @@ -53,20 +57,29 @@ class RetrieverFactory(ConfigBasedFactory): def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever: vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions)) config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + return FAISSRetriever(**config.model_dump()) def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever: config.index = copy.deepcopy(self._extract_index(config, **kwargs)) - nodes = list(config.index.docstore.docs.values()) - return DynamicBM25Retriever(nodes=nodes, **config.model_dump()) + + return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump()) def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever: db = chromadb.PersistentClient(path=str(config.persist_path)) chroma_collection = db.get_or_create_collection(config.collection_name) + vector_store = ChromaVectorStore(chroma_collection=chroma_collection) config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + return ChromaRetriever(**config.model_dump()) + def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever: + vector_store = ElasticsearchStore(**config.store_config.model_dump()) + config.index = self._build_index_from_vector_store(config, vector_store, **kwargs) + + return ElasticsearchRetriever(**config.model_dump()) + def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex: return self._val_from_config_or_kwargs("index", config, **kwargs) diff --git a/metagpt/rag/retrievers/es_retriever.py b/metagpt/rag/retrievers/es_retriever.py new file mode 100644 index 000000000..a1a0a6138 --- /dev/null +++ b/metagpt/rag/retrievers/es_retriever.py @@ -0,0 +1,17 @@ +"""Elasticsearch retriever.""" + +from llama_index.core.retrievers import VectorIndexRetriever +from llama_index.core.schema import BaseNode + + +class ElasticsearchRetriever(VectorIndexRetriever): + """Elasticsearch retriever.""" + + def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None: + """Support add nodes.""" + self._index.insert_nodes(nodes, **kwargs) + + def persist(self, persist_dir: str, **kwargs) -> None: + """Support persist. + + Elasticsearch automatically saves, so there is no need to implement.""" diff --git a/metagpt/rag/retrievers/faiss_retriever.py b/metagpt/rag/retrievers/faiss_retriever.py index 7e543cce2..80b409292 100644 --- a/metagpt/rag/retrievers/faiss_retriever.py +++ b/metagpt/rag/retrievers/faiss_retriever.py @@ -8,7 +8,7 @@ class FAISSRetriever(VectorIndexRetriever): """FAISS retriever.""" def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None: - """Support add nodes""" + """Support add nodes.""" self._index.insert_nodes(nodes, **kwargs) def persist(self, persist_dir: str, **kwargs) -> None: diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index cae1c2979..e98a6fc89 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -6,6 +6,7 @@ from typing import Any, Union from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex from llama_index.core.schema import TextNode +from llama_index.core.vector_stores.types import VectorStoreQueryMode from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from metagpt.rag.interface import RAGObject @@ -46,6 +47,24 @@ class ChromaRetrieverConfig(IndexRetrieverConfig): collection_name: str = Field(default="metagpt", description="The name of the collection.") +class ElasticsearchStoreConfig(BaseModel): + index_name: str = Field(default="metagpt", description="Name of the Elasticsearch index.") + es_url: str = Field(default=None, description="Elasticsearch URL.") + es_cloud_id: str = Field(default=None, description="Elasticsearch cloud ID.") + es_api_key: str = Field(default=None, description="Elasticsearch API key.") + es_user: str = Field(default=None, description="Elasticsearch username.") + es_password: str = Field(default=None, description="Elasticsearch password.") + batch_size: int = Field(default=200, description="Batch size for bulk indexing.") + distance_strategy: str = Field(default="COSINE", description="Distance strategy to use for similarity search.") + + +class ElasticsearchRetrieverConfig(IndexRetrieverConfig): + """Config for Elasticsearch-based retrievers.""" + + store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.") + vector_store_query_mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT + + class BaseRankerConfig(BaseModel): """Common config for rankers. @@ -53,7 +72,6 @@ class BaseRankerConfig(BaseModel): """ model_config = ConfigDict(arbitrary_types_allowed=True) - top_n: int = Field(default=5, description="The number of top results to return.") @@ -72,6 +90,7 @@ class BaseIndexConfig(BaseModel): If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.index. """ + model_config = ConfigDict(arbitrary_types_allowed=True) persist_path: Union[str, Path] = Field(description="The directory of saved data.") @@ -97,6 +116,13 @@ class BM25IndexConfig(BaseIndexConfig): _no_embedding: bool = PrivateAttr(default=True) +class ElasticsearchIndexConfig(VectorIndexConfig): + """Config for es-based index.""" + + store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.") + persist_path: Union[str, Path] = "" + + class ObjectNodeMetadata(BaseModel): """Metadata of ObjectNode.""" diff --git a/requirements.txt b/requirements.txt index 326fa8bb9..3e545d146 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,6 +17,7 @@ llama-index-llms-azure-openai==0.1.4 llama-index-readers-file==0.1.4 llama-index-retrievers-bm25==0.1.3 llama-index-vector-stores-faiss==0.1.1 +llama-index-vector-stores-elasticsearch==0.1.5 chromadb==0.4.23 loguru==0.6.0 meilisearch==0.21.0 @@ -76,3 +77,5 @@ Pillow imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py qianfan==0.3.2 dashscope==1.14.1 +rank-bm25==0.2.2 # for tool recommendation +jieba==0.42.1 # for tool recommendation From e53188f8981d7748343e902821b544a59170fd6b Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 21 Mar 2024 16:39:53 +0800 Subject: [PATCH 02/12] fix potential pydantic ValidationError --- metagpt/rag/engines/simple.py | 4 +++- metagpt/rag/factories/llm.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 02f9ca7b1..5c5810308 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -130,10 +130,12 @@ class SimpleEngine(RetrieverQueryEngine): retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever. ranker_configs: Configuration for rankers. """ + objs = objs or [] + retriever_configs = retriever_configs or [] + if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs): raise ValueError("In BM25RetrieverConfig, Objs must not be empty.") - objs = objs or [] nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs] index = VectorStoreIndex( nodes=nodes, diff --git a/metagpt/rag/factories/llm.py b/metagpt/rag/factories/llm.py index 1cdbab14d..17c499b76 100644 --- a/metagpt/rag/factories/llm.py +++ b/metagpt/rag/factories/llm.py @@ -33,7 +33,9 @@ class RAGLLM(CustomLLM): @property def metadata(self) -> LLMMetadata: """Get LLM metadata.""" - return LLMMetadata(context_window=self.context_window, num_output=self.num_output, model_name=self.model_name) + return LLMMetadata( + context_window=self.context_window, num_output=self.num_output, model_name=self.model_name or "unknown" + ) @llm_completion_callback() def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: From 6e30b42cc0ee343ce7f9a706632b4fac1c71744a Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 21 Mar 2024 16:50:59 +0800 Subject: [PATCH 03/12] add FLAREEngine and ColbertRerank --- metagpt/rag/engines/__init__.py | 3 ++- metagpt/rag/engines/flare.py | 0 metagpt/rag/factories/ranker.py | 10 ++++++---- metagpt/rag/schema.py | 6 ++++++ requirements.txt | 1 + 5 files changed, 15 insertions(+), 5 deletions(-) create mode 100644 metagpt/rag/engines/flare.py diff --git a/metagpt/rag/engines/__init__.py b/metagpt/rag/engines/__init__.py index 373181384..93699db88 100644 --- a/metagpt/rag/engines/__init__.py +++ b/metagpt/rag/engines/__init__.py @@ -1,5 +1,6 @@ """Engines init""" from metagpt.rag.engines.simple import SimpleEngine +from metagpt.rag.engines.flare import FLAREEngine -__all__ = ["SimpleEngine"] +__all__ = ["SimpleEngine", "FLAREEngine"] diff --git a/metagpt/rag/engines/flare.py b/metagpt/rag/engines/flare.py new file mode 100644 index 000000000..e69de29bb diff --git a/metagpt/rag/factories/ranker.py b/metagpt/rag/factories/ranker.py index f05599e15..15dc55bf9 100644 --- a/metagpt/rag/factories/ranker.py +++ b/metagpt/rag/factories/ranker.py @@ -3,18 +3,17 @@ from llama_index.core.llms import LLM from llama_index.core.postprocessor import LLMRerank from llama_index.core.postprocessor.types import BaseNodePostprocessor +from llama_index.postprocessor.colbert_rerank import ColbertRerank from metagpt.rag.factories.base import ConfigBasedFactory -from metagpt.rag.schema import BaseRankerConfig, LLMRankerConfig +from metagpt.rag.schema import BaseRankerConfig, ColbertRerankConfig, LLMRankerConfig class RankerFactory(ConfigBasedFactory): """Modify creators for dynamically instance implementation.""" def __init__(self): - creators = { - LLMRankerConfig: self._create_llm_ranker, - } + creators = {LLMRankerConfig: self._create_llm_ranker, ColbertRerankConfig: self._create_colbert_ranker} super().__init__(creators) def get_rankers(self, configs: list[BaseRankerConfig] = None, **kwargs) -> list[BaseNodePostprocessor]: @@ -28,6 +27,9 @@ class RankerFactory(ConfigBasedFactory): config.llm = self._extract_llm(config, **kwargs) return LLMRerank(**config.model_dump()) + def _create_colbert_ranker(self, config: ColbertRerankConfig, **kwargs) -> LLMRerank: + return ColbertRerank(**config.model_dump()) + def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM: return self._val_from_config_or_kwargs("llm", config, **kwargs) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index e98a6fc89..cacce3178 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -84,6 +84,12 @@ class LLMRankerConfig(BaseRankerConfig): ) +class ColbertRerankConfig(BaseRankerConfig): + model: str = Field(default="colbert-ir/colbertv2.0", description="Colbert model name.") + device: str = Field(default="cpu", description="Device to use for sentence transformer.") + keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.") + + class BaseIndexConfig(BaseModel): """Common config for index. diff --git a/requirements.txt b/requirements.txt index 3e545d146..9bcd2a45b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,6 +18,7 @@ llama-index-readers-file==0.1.4 llama-index-retrievers-bm25==0.1.3 llama-index-vector-stores-faiss==0.1.1 llama-index-vector-stores-elasticsearch==0.1.5 +llama-index-postprocessor-colbert-rerank==0.1.1 chromadb==0.4.23 loguru==0.6.0 meilisearch==0.21.0 From 73953c025d16ec99994f2262fa8cae9b6aa0f58c Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 21 Mar 2024 17:11:18 +0800 Subject: [PATCH 04/12] add FLAREEngine and ColbertRerank --- metagpt/rag/engines/flare.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/metagpt/rag/engines/flare.py b/metagpt/rag/engines/flare.py index e69de29bb..3fd1bf84b 100644 --- a/metagpt/rag/engines/flare.py +++ b/metagpt/rag/engines/flare.py @@ -0,0 +1,9 @@ +"""FLARE Engine. + +Use llamaindex's FLAREInstructQueryEngine, which accepts other engines as parameters. +For example, Create a simple engine, and then pass it to FLAREEngine. +""" + +from llama_index.core.query_engine import ( # noqa: F401 + FLAREInstructQueryEngine as FLAREEngine, +) From 7c1c4b2a35659520e4f8e779acbeba54dd1cab91 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 21 Mar 2024 17:14:22 +0800 Subject: [PATCH 05/12] update comment --- metagpt/rag/engines/flare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metagpt/rag/engines/flare.py b/metagpt/rag/engines/flare.py index 3fd1bf84b..dc05bd3dd 100644 --- a/metagpt/rag/engines/flare.py +++ b/metagpt/rag/engines/flare.py @@ -1,6 +1,6 @@ """FLARE Engine. -Use llamaindex's FLAREInstructQueryEngine, which accepts other engines as parameters. +Use llamaindex's FLAREInstructQueryEngine as FLAREEngine, which accepts other engines as parameters. For example, Create a simple engine, and then pass it to FLAREEngine. """ From 34a3c1ad0753316188655f8b90fa1996a2f95523 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Fri, 22 Mar 2024 15:58:59 +0800 Subject: [PATCH 06/12] upgrade llama-index-vector-stores-elasticsearch --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9bcd2a45b..6e84f4612 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,7 +17,7 @@ llama-index-llms-azure-openai==0.1.4 llama-index-readers-file==0.1.4 llama-index-retrievers-bm25==0.1.3 llama-index-vector-stores-faiss==0.1.1 -llama-index-vector-stores-elasticsearch==0.1.5 +llama-index-vector-stores-elasticsearch==0.1.6 llama-index-postprocessor-colbert-rerank==0.1.1 chromadb==0.4.23 loguru==0.6.0 From 092ef26425279f76318366e880246c85739940fb Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Fri, 22 Mar 2024 17:30:35 +0800 Subject: [PATCH 07/12] support elasticsearch text only --- metagpt/rag/factories/base.py | 2 +- metagpt/rag/factories/index.py | 2 ++ metagpt/rag/factories/retriever.py | 2 ++ metagpt/rag/schema.py | 15 ++++++++++++++- 4 files changed, 19 insertions(+), 2 deletions(-) diff --git a/metagpt/rag/factories/base.py b/metagpt/rag/factories/base.py index 8f8155914..fbdfbf1a8 100644 --- a/metagpt/rag/factories/base.py +++ b/metagpt/rag/factories/base.py @@ -41,7 +41,7 @@ class ConfigBasedFactory(GenericFactory): if creator: return creator(key, **kwargs) - raise ValueError(f"Unknown config: {key}") + raise ValueError(f"Unknown config: `{type(key)}`, {key}") @staticmethod def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any: diff --git a/metagpt/rag/factories/index.py b/metagpt/rag/factories/index.py index 5ab7992a0..f200fc94f 100644 --- a/metagpt/rag/factories/index.py +++ b/metagpt/rag/factories/index.py @@ -14,6 +14,7 @@ from metagpt.rag.schema import ( BM25IndexConfig, ChromaIndexConfig, ElasticsearchIndexConfig, + ElasticsearchKeywordIndexConfig, FAISSIndexConfig, ) from metagpt.rag.vector_stores.chroma import ChromaVectorStore @@ -26,6 +27,7 @@ class RAGIndexFactory(ConfigBasedFactory): ChromaIndexConfig: self._create_chroma, BM25IndexConfig: self._create_bm25, ElasticsearchIndexConfig: self._create_es, + ElasticsearchKeywordIndexConfig: self._create_es, } super().__init__(creators) diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index 47ceadf00..a107d9573 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -20,6 +20,7 @@ from metagpt.rag.schema import ( BaseRetrieverConfig, BM25RetrieverConfig, ChromaRetrieverConfig, + ElasticsearchKeywordRetrieverConfig, ElasticsearchRetrieverConfig, FAISSRetrieverConfig, IndexRetrieverConfig, @@ -36,6 +37,7 @@ class RetrieverFactory(ConfigBasedFactory): BM25RetrieverConfig: self._create_bm25_retriever, ChromaRetrieverConfig: self._create_chroma_retriever, ElasticsearchRetrieverConfig: self._create_es_retriever, + ElasticsearchKeywordRetrieverConfig: self._create_es_retriever, } super().__init__(creators) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index cacce3178..cb5f1aac0 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -59,12 +59,19 @@ class ElasticsearchStoreConfig(BaseModel): class ElasticsearchRetrieverConfig(IndexRetrieverConfig): - """Config for Elasticsearch-based retrievers.""" + """Config for Elasticsearch-based retrievers. Support both vector and text.""" store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.") vector_store_query_mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT +class ElasticsearchKeywordRetrieverConfig(ElasticsearchRetrieverConfig): + """Config for Elasticsearch-based retrievers. Support text only.""" + + _no_embedding: bool = PrivateAttr(default=True) + vector_store_query_mode: VectorStoreQueryMode = VectorStoreQueryMode.TEXT_SEARCH + + class BaseRankerConfig(BaseModel): """Common config for rankers. @@ -129,6 +136,12 @@ class ElasticsearchIndexConfig(VectorIndexConfig): persist_path: Union[str, Path] = "" +class ElasticsearchKeywordIndexConfig(ElasticsearchIndexConfig): + """Config for es-based index. no embedding.""" + + _no_embedding: bool = PrivateAttr(default=True) + + class ObjectNodeMetadata(BaseModel): """Metadata of ObjectNode.""" From aaae00441b21945009f2594003a57a3b5e8bdee2 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Fri, 22 Mar 2024 18:05:32 +0800 Subject: [PATCH 08/12] use Literal to restrict vector_store_query_mode of ElasticsearchKeywordRetrieverConfig --- metagpt/rag/schema.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index cb5f1aac0..0711f5c83 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -1,7 +1,7 @@ """RAG schemas.""" from pathlib import Path -from typing import Any, Union +from typing import Any, Literal, Union from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex @@ -62,14 +62,18 @@ class ElasticsearchRetrieverConfig(IndexRetrieverConfig): """Config for Elasticsearch-based retrievers. Support both vector and text.""" store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.") - vector_store_query_mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT + vector_store_query_mode: VectorStoreQueryMode = Field( + default=VectorStoreQueryMode.DEFAULT, description="default is vector query." + ) class ElasticsearchKeywordRetrieverConfig(ElasticsearchRetrieverConfig): """Config for Elasticsearch-based retrievers. Support text only.""" _no_embedding: bool = PrivateAttr(default=True) - vector_store_query_mode: VectorStoreQueryMode = VectorStoreQueryMode.TEXT_SEARCH + vector_store_query_mode: Literal[VectorStoreQueryMode.TEXT_SEARCH] = Field( + default=VectorStoreQueryMode.TEXT_SEARCH, description="text query only." + ) class BaseRankerConfig(BaseModel): From a22d7d89830970c68bacb7cdf9a1cf33c4e29a18 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Tue, 26 Mar 2024 16:36:45 +0800 Subject: [PATCH 09/12] add object ranker --- metagpt/rag/factories/ranker.py | 17 +++++- metagpt/rag/rankers/object_ranker.py | 54 +++++++++++++++++ metagpt/rag/schema.py | 5 ++ .../metagpt/rag/rankers/test_object_ranker.py | 60 +++++++++++++++++++ 4 files changed, 134 insertions(+), 2 deletions(-) create mode 100644 metagpt/rag/rankers/object_ranker.py create mode 100644 tests/metagpt/rag/rankers/test_object_ranker.py diff --git a/metagpt/rag/factories/ranker.py b/metagpt/rag/factories/ranker.py index 15dc55bf9..07cb1b929 100644 --- a/metagpt/rag/factories/ranker.py +++ b/metagpt/rag/factories/ranker.py @@ -6,14 +6,24 @@ from llama_index.core.postprocessor.types import BaseNodePostprocessor from llama_index.postprocessor.colbert_rerank import ColbertRerank from metagpt.rag.factories.base import ConfigBasedFactory -from metagpt.rag.schema import BaseRankerConfig, ColbertRerankConfig, LLMRankerConfig +from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor +from metagpt.rag.schema import ( + BaseRankerConfig, + ColbertRerankConfig, + LLMRankerConfig, + ObjectRankerConfig, +) class RankerFactory(ConfigBasedFactory): """Modify creators for dynamically instance implementation.""" def __init__(self): - creators = {LLMRankerConfig: self._create_llm_ranker, ColbertRerankConfig: self._create_colbert_ranker} + creators = { + LLMRankerConfig: self._create_llm_ranker, + ColbertRerankConfig: self._create_colbert_ranker, + ObjectRankerConfig: self._create_object_ranker, + } super().__init__(creators) def get_rankers(self, configs: list[BaseRankerConfig] = None, **kwargs) -> list[BaseNodePostprocessor]: @@ -30,6 +40,9 @@ class RankerFactory(ConfigBasedFactory): def _create_colbert_ranker(self, config: ColbertRerankConfig, **kwargs) -> LLMRerank: return ColbertRerank(**config.model_dump()) + def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank: + return ObjectSortPostprocessor(**config.model_dump()) + def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM: return self._val_from_config_or_kwargs("llm", config, **kwargs) diff --git a/metagpt/rag/rankers/object_ranker.py b/metagpt/rag/rankers/object_ranker.py new file mode 100644 index 000000000..fe45f9395 --- /dev/null +++ b/metagpt/rag/rankers/object_ranker.py @@ -0,0 +1,54 @@ +"""Object ranker.""" + +import heapq +import json +from typing import Literal, Optional + +from llama_index.core.postprocessor.types import BaseNodePostprocessor +from llama_index.core.schema import NodeWithScore, QueryBundle +from pydantic import Field + +from metagpt.rag.schema import ObjectNode + + +class ObjectSortPostprocessor(BaseNodePostprocessor): + """Sorted by object's field, desc or asc. + + Assumes nodes is list of ObjectNode with score. + """ + + field_name: str = Field(..., description="field name of the object, field's value must can be compared.") + order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.") + top_n: int = 5 + + @classmethod + def class_name(cls) -> str: + return "ObjectSortPostprocessor" + + def _postprocess_nodes( + self, + nodes: list[NodeWithScore], + query_bundle: Optional[QueryBundle] = None, + ) -> list[NodeWithScore]: + """Postprocess nodes.""" + if query_bundle is None: + raise ValueError("Missing query bundle in extra info.") + + if not nodes: + return [] + + self._check_metadata(nodes[0].node) + sort_key = lambda node: json.loads(node.node.metadata["obj_json"])[self.field_name] + return self._get_sort_func()(self.top_n, nodes, key=sort_key) + + def _get_sort_func(self): + return heapq.nlargest if self.order == "desc" else heapq.nsmallest + + def _check_metadata(self, node: ObjectNode): + try: + obj_dict = json.loads(node.metadata.get("obj_json")) + except Exception as e: + raise ValueError(f"Invalid object json in metadata: {node.metadata}, error: {e}") + + if self.field_name not in obj_dict: + raise ValueError(f"Field '{self.field_name}' not found in object: {obj_dict}") diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 0711f5c83..183f6e0c7 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -101,6 +101,11 @@ class ColbertRerankConfig(BaseRankerConfig): keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.") +class ObjectRankerConfig(BaseRankerConfig): + field_name: str = Field(..., description="field name of the object, field's value must can be compared.") + order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.") + + class BaseIndexConfig(BaseModel): """Common config for index. diff --git a/tests/metagpt/rag/rankers/test_object_ranker.py b/tests/metagpt/rag/rankers/test_object_ranker.py new file mode 100644 index 000000000..7ea6b7488 --- /dev/null +++ b/tests/metagpt/rag/rankers/test_object_ranker.py @@ -0,0 +1,60 @@ +import json + +import pytest +from llama_index.core.schema import NodeWithScore, QueryBundle +from pydantic import BaseModel + +from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor +from metagpt.rag.schema import ObjectNode + + +class Record(BaseModel): + score: int + + +class TestObjectSortPostprocessor: + @pytest.fixture + def nodes_with_scores(self): + nodes = [ + NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=10).model_dump_json()}), score=10), + NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=20).model_dump_json()}), score=20), + NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=5).model_dump_json()}), score=5), + ] + return nodes + + @pytest.fixture + def query_bundle(self, mocker): + return mocker.MagicMock(spec=QueryBundle) + + def test_sort_descending(self, nodes_with_scores, query_bundle): + postprocessor = ObjectSortPostprocessor(field_name="score", order="desc") + sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle) + assert [node.score for node in sorted_nodes] == [20, 10, 5] + + def test_sort_ascending(self, nodes_with_scores, query_bundle): + postprocessor = ObjectSortPostprocessor(field_name="score", order="asc") + sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle) + assert [node.score for node in sorted_nodes] == [5, 10, 20] + + def test_top_n_limit(self, nodes_with_scores, query_bundle): + postprocessor = ObjectSortPostprocessor(field_name="score", order="desc", top_n=2) + sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle) + assert len(sorted_nodes) == 2 + assert [node.score for node in sorted_nodes] == [20, 10] + + def test_invalid_json_metadata(self, query_bundle): + nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": "invalid_json"}), score=10)] + postprocessor = ObjectSortPostprocessor(field_name="score", order="desc") + with pytest.raises(ValueError): + postprocessor._postprocess_nodes(nodes, query_bundle) + + def test_missing_query_bundle(self, nodes_with_scores): + postprocessor = ObjectSortPostprocessor(field_name="score", order="desc") + with pytest.raises(ValueError): + postprocessor._postprocess_nodes(nodes_with_scores, query_bundle=None) + + def test_field_not_found_in_object(self): + nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": json.dumps({"not_score": 10})}), score=10)] + postprocessor = ObjectSortPostprocessor(field_name="score", order="desc") + with pytest.raises(ValueError): + postprocessor._postprocess_nodes(nodes) From 1eb141a45f794af987f171442820703e447e3e53 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Tue, 26 Mar 2024 16:40:10 +0800 Subject: [PATCH 10/12] add object ranker --- metagpt/rag/rankers/object_ranker.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/metagpt/rag/rankers/object_ranker.py b/metagpt/rag/rankers/object_ranker.py index fe45f9395..b8456803f 100644 --- a/metagpt/rag/rankers/object_ranker.py +++ b/metagpt/rag/rankers/object_ranker.py @@ -38,12 +38,10 @@ class ObjectSortPostprocessor(BaseNodePostprocessor): return [] self._check_metadata(nodes[0].node) + sort_key = lambda node: json.loads(node.node.metadata["obj_json"])[self.field_name] return self._get_sort_func()(self.top_n, nodes, key=sort_key) - def _get_sort_func(self): - return heapq.nlargest if self.order == "desc" else heapq.nsmallest - def _check_metadata(self, node: ObjectNode): try: obj_dict = json.loads(node.metadata.get("obj_json")) @@ -52,3 +50,6 @@ class ObjectSortPostprocessor(BaseNodePostprocessor): if self.field_name not in obj_dict: raise ValueError(f"Field '{self.field_name}' not found in object: {obj_dict}") + + def _get_sort_func(self): + return heapq.nlargest if self.order == "desc" else heapq.nsmallest From 8d98ce34e54eb6250f1f2cf60f5d4dd66d462a5d Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Wed, 27 Mar 2024 11:15:10 +0800 Subject: [PATCH 11/12] fix by cr --- examples/rag_pipeline.py | 29 +++++++++-------------------- requirements.txt | 10 ---------- setup.py | 2 ++ 3 files changed, 11 insertions(+), 30 deletions(-) diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index ae6e7b7bc..47137c0a4 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -1,7 +1,6 @@ """RAG pipeline""" import asyncio -from functools import wraps from pydantic import BaseModel @@ -18,6 +17,7 @@ from metagpt.rag.schema import ( FAISSRetrieverConfig, LLMRankerConfig, ) +from metagpt.utils.exceptions import handle_exception DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt" QUESTION = "What are key qualities to be a good writer?" @@ -28,17 +28,6 @@ TRAVEL_QUESTION = "What does Bob like?" LLM_TIP = "If you not sure, just answer I don't know." -def catch_exception(func): - @wraps(func) - async def wrapper(*args, **kwargs): - try: - return await func(*args, **kwargs) - except Exception as e: - logger.error(f"{func.__name__} exception: {e}") - - return wrapper - - class Player(BaseModel): """To demonstrate rag add objs.""" @@ -122,7 +111,7 @@ class RAGExample: self.engine.add_docs([travel_filepath]) await self.run_pipeline(question=travel_question, print_title=False) - @catch_exception + @handle_exception async def add_objects(self, print_title=True): """This example show how to add objects. @@ -180,21 +169,21 @@ class RAGExample: """ self._print_title("Init And Query ChromaDB") - # 1.save index + # 1. save index output_dir = DATA_PATH / "rag" SimpleEngine.from_docs( input_files=[TRAVEL_DOC_PATH], retriever_configs=[ChromaRetrieverConfig(persist_path=output_dir)], ) - # 2.load index + # 2. load index engine = SimpleEngine.from_index(index_config=ChromaIndexConfig(persist_path=output_dir)) - # 3.query + # 3. query answer = await engine.aquery(TRAVEL_QUESTION) self._print_query_result(answer) - @catch_exception + @handle_exception async def init_and_query_es(self): """This example show how to use es. how to save and load index. will print something like: @@ -205,17 +194,17 @@ class RAGExample: """ self._print_title("Init And Query Elasticsearch") - # 1.create es index and save docs + # 1. create es index and save docs store_config = ElasticsearchStoreConfig(index_name="travel", es_url="http://127.0.0.1:9200") engine = SimpleEngine.from_docs( input_files=[TRAVEL_DOC_PATH], retriever_configs=[ElasticsearchRetrieverConfig(store_config=store_config)], ) - # 2.load index + # 2. load index engine = SimpleEngine.from_index(index_config=ElasticsearchIndexConfig(store_config=store_config)) - # 3.query + # 3. query answer = await engine.aquery(TRAVEL_QUESTION) self._print_query_result(answer) diff --git a/requirements.txt b/requirements.txt index fef56e810..da8aa26b2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,16 +10,6 @@ typer==0.9.0 # godot==0.1.1 # google_api_python_client==2.93.0 # Used by search_engine.py lancedb==0.4.0 -llama-index-core==0.10.15 -llama-index-embeddings-azure-openai==0.1.6 -llama-index-embeddings-openai==0.1.5 -llama-index-llms-azure-openai==0.1.4 -llama-index-readers-file==0.1.4 -llama-index-retrievers-bm25==0.1.3 -llama-index-vector-stores-faiss==0.1.1 -llama-index-vector-stores-elasticsearch==0.1.6 -llama-index-postprocessor-colbert-rerank==0.1.1 -chromadb==0.4.23 loguru==0.6.0 meilisearch==0.21.0 numpy==1.24.3 diff --git a/setup.py b/setup.py index f834b4c44..c728872ef 100644 --- a/setup.py +++ b/setup.py @@ -36,6 +36,8 @@ extras_require = { "llama-index-readers-file==0.1.4", "llama-index-retrievers-bm25==0.1.3", "llama-index-vector-stores-faiss==0.1.1", + "llama-index-vector-stores-elasticsearch==0.1.6", + "llama-index-postprocessor-colbert-rerank==0.1.1", "chromadb==0.4.23", ], } From 90e1b629341abbf12a8b6f16910d4548c8ea2c79 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Wed, 27 Mar 2024 11:28:23 +0800 Subject: [PATCH 12/12] rm unnecessary comment --- examples/rag_pipeline.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index 47137c0a4..b5111b75c 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -189,8 +189,6 @@ class RAGExample: Query Result: Bob likes traveling. - - If `Unclosed client session`, it's llamaindex elasticsearch problem, maybe fixed later. """ self._print_title("Init And Query Elasticsearch")