rag add es

This commit is contained in:
seehi 2024-03-21 12:04:06 +08:00
parent 8c218a1e55
commit 191a86f93e
7 changed files with 157 additions and 31 deletions

View file

@ -1,6 +1,7 @@
"""RAG pipeline"""
import asyncio
from functools import wraps
from pydantic import BaseModel
@ -11,6 +12,9 @@ from metagpt.rag.schema import (
BM25RetrieverConfig,
ChromaIndexConfig,
ChromaRetrieverConfig,
ElasticsearchIndexConfig,
ElasticsearchRetrieverConfig,
ElasticsearchStoreConfig,
FAISSRetrieverConfig,
LLMRankerConfig,
)
@ -24,6 +28,17 @@ TRAVEL_QUESTION = "What does Bob like?"
LLM_TIP = "If you not sure, just answer I don't know."
def catch_exception(func):
    """Decorate an async function so any exception is logged rather than raised.

    The decorated call returns the wrapped function's result on success and
    (implicitly) None after logging on failure — useful for demo pipelines
    that should keep running past individual step errors.
    """

    @wraps(func)
    async def guarded(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except Exception as e:
            logger.error(f"{func.__name__} exception: {e}")

    return guarded
class Player(BaseModel):
"""To demonstrate rag add objs."""
@ -39,12 +54,22 @@ class Player(BaseModel):
class RAGExample:
"""Show how to use RAG."""
def __init__(self):
self.engine = SimpleEngine.from_docs(
input_files=[DOC_PATH],
retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
ranker_configs=[LLMRankerConfig()],
)
def __init__(self, engine: SimpleEngine = None):
self._engine = engine
@property
def engine(self):
if not self._engine:
self._engine = SimpleEngine.from_docs(
input_files=[DOC_PATH],
retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
ranker_configs=[LLMRankerConfig()],
)
return self._engine
@engine.setter
def engine(self, value: SimpleEngine):
self._engine = value
async def run_pipeline(self, question=QUESTION, print_title=True):
"""This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like:
@ -97,6 +122,7 @@ class RAGExample:
self.engine.add_docs([travel_filepath])
await self.run_pipeline(question=travel_question, print_title=False)
@catch_exception
async def add_objects(self, print_title=True):
"""This example show how to add objects.
@ -154,20 +180,43 @@ class RAGExample:
"""
self._print_title("Init And Query ChromaDB")
# save index
# 1.save index
output_dir = DATA_PATH / "rag"
SimpleEngine.from_docs(
input_files=[TRAVEL_DOC_PATH],
retriever_configs=[ChromaRetrieverConfig(persist_path=output_dir)],
)
# load index
engine = SimpleEngine.from_index(
index_config=ChromaIndexConfig(persist_path=output_dir),
# 2.load index
engine = SimpleEngine.from_index(index_config=ChromaIndexConfig(persist_path=output_dir))
# 3.query
answer = await engine.aquery(TRAVEL_QUESTION)
self._print_query_result(answer)
@catch_exception
async def init_and_query_es(self):
"""This example show how to use es. how to save and load index. will print something like:
Query Result:
Bob likes traveling.
If `Unclosed client session`, it's llamaindex elasticsearch problem, maybe fixed later.
"""
self._print_title("Init And Query Elasticsearch")
# 1.create es index and save docs
store_config = ElasticsearchStoreConfig(index_name="travel", es_url="http://127.0.0.1:9200")
engine = SimpleEngine.from_docs(
input_files=[TRAVEL_DOC_PATH],
retriever_configs=[ElasticsearchRetrieverConfig(store_config=store_config)],
)
# query
answer = engine.query(TRAVEL_QUESTION)
# 2.load index
engine = SimpleEngine.from_index(index_config=ElasticsearchIndexConfig(store_config=store_config))
# 3.query
answer = await engine.aquery(TRAVEL_QUESTION)
self._print_query_result(answer)
@staticmethod
@ -205,6 +254,7 @@ async def main():
await e.add_objects()
await e.init_objects()
await e.init_and_query_chromadb()
await e.init_and_query_es()
if __name__ == "__main__":

View file

@ -4,6 +4,8 @@ import chromadb
from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.factories.base import ConfigBasedFactory
@ -11,6 +13,7 @@ from metagpt.rag.schema import (
BaseIndexConfig,
BM25IndexConfig,
ChromaIndexConfig,
ElasticsearchIndexConfig,
FAISSIndexConfig,
)
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
@ -22,6 +25,7 @@ class RAGIndexFactory(ConfigBasedFactory):
FAISSIndexConfig: self._create_faiss,
ChromaIndexConfig: self._create_chroma,
BM25IndexConfig: self._create_bm25,
ElasticsearchIndexConfig: self._create_es,
}
super().__init__(creators)
@ -30,31 +34,44 @@ class RAGIndexFactory(ConfigBasedFactory):
return super().get_instance(config, **kwargs)
def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex:
embed_model = self._extract_embed_model(config, **kwargs)
vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path))
storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=config.persist_path)
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
return index
return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)
def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex:
    """Load a BM25 index back from its persisted storage directory."""
    ctx = StorageContext.from_defaults(persist_dir=config.persist_path)
    return self._index_from_storage(storage_context=ctx, config=config, **kwargs)
def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
embed_model = self._extract_embed_model(config, **kwargs)
db = chromadb.PersistentClient(str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name)
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
index = VectorStoreIndex.from_vector_store(
vector_store,
embed_model=embed_model,
)
return index
def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex:
return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)
def _create_es(self, config: ElasticsearchIndexConfig, **kwargs) -> VectorStoreIndex:
    """Build an index on top of an Elasticsearch-backed vector store."""
    es_store = ElasticsearchStore(**config.store_config.model_dump())
    return self._index_from_vector_store(vector_store=es_store, config=config, **kwargs)
def _index_from_storage(
self, storage_context: StorageContext, config: BaseIndexConfig, **kwargs
) -> VectorStoreIndex:
embed_model = self._extract_embed_model(config, **kwargs)
storage_context = StorageContext.from_defaults(persist_dir=config.persist_path)
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
return index
return load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
def _index_from_vector_store(
    self, vector_store: BasePydanticVectorStore, config: BaseIndexConfig, **kwargs
) -> VectorStoreIndex:
    """Wrap an existing vector store in a VectorStoreIndex.

    The embed model is resolved from the config or the call kwargs, matching
    how the storage-based path resolves it.
    """
    model = self._extract_embed_model(config, **kwargs)
    return VectorStoreIndex.from_vector_store(vector_store=vector_store, embed_model=model)
def _extract_embed_model(self, config, **kwargs) -> BaseEmbedding:
return self._val_from_config_or_kwargs("embed_model", config, **kwargs)

View file

@ -6,18 +6,21 @@ import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BaseRetrieverConfig,
BM25RetrieverConfig,
ChromaRetrieverConfig,
ElasticsearchRetrieverConfig,
FAISSRetrieverConfig,
IndexRetrieverConfig,
)
@ -32,6 +35,7 @@ class RetrieverFactory(ConfigBasedFactory):
FAISSRetrieverConfig: self._create_faiss_retriever,
BM25RetrieverConfig: self._create_bm25_retriever,
ChromaRetrieverConfig: self._create_chroma_retriever,
ElasticsearchRetrieverConfig: self._create_es_retriever,
}
super().__init__(creators)
@ -53,20 +57,29 @@ class RetrieverFactory(ConfigBasedFactory):
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
    """Create a FAISS retriever over a fresh flat-L2 index of the configured dimensionality."""
    faiss_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
    config.index = self._build_index_from_vector_store(config, faiss_store, **kwargs)
    return FAISSRetriever(**config.model_dump())
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
nodes = list(config.index.docstore.docs.values())
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump())
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
    """Create a retriever backed by a persistent local Chroma collection."""
    client = chromadb.PersistentClient(path=str(config.persist_path))
    collection = client.get_or_create_collection(config.collection_name)
    chroma_store = ChromaVectorStore(chroma_collection=collection)
    config.index = self._build_index_from_vector_store(config, chroma_store, **kwargs)
    return ChromaRetriever(**config.model_dump())
def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever:
    """Create a retriever whose vector store lives in Elasticsearch."""
    es_store = ElasticsearchStore(**config.store_config.model_dump())
    config.index = self._build_index_from_vector_store(config, es_store, **kwargs)
    return ElasticsearchRetriever(**config.model_dump())
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
    """Fetch the `index` value from the config or from kwargs.

    Resolution order is delegated to `_val_from_config_or_kwargs`;
    presumably config wins over kwargs — confirm in ConfigBasedFactory.
    """
    return self._val_from_config_or_kwargs("index", config, **kwargs)

View file

@ -0,0 +1,17 @@
"""Elasticsearch retriever."""
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode
class ElasticsearchRetriever(VectorIndexRetriever):
    """Vector retriever whose index is backed by an Elasticsearch store."""

    def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
        """Insert nodes into the underlying index (extends the base retriever)."""
        self._index.insert_nodes(nodes, **kwargs)

    def persist(self, persist_dir: str, **kwargs) -> None:
        """Intentional no-op: Elasticsearch persists writes server-side,
        so there is nothing to save to a local directory."""

View file

@ -8,7 +8,7 @@ class FAISSRetriever(VectorIndexRetriever):
"""FAISS retriever."""
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
"""Support add nodes"""
"""Support add nodes."""
self._index.insert_nodes(nodes, **kwargs)
def persist(self, persist_dir: str, **kwargs) -> None:

View file

@ -6,6 +6,7 @@ from typing import Any, Union
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from metagpt.rag.interface import RAGObject
@ -46,6 +47,24 @@ class ChromaRetrieverConfig(IndexRetrieverConfig):
collection_name: str = Field(default="metagpt", description="The name of the collection.")
class ElasticsearchStoreConfig(BaseModel):
    """Connection settings forwarded verbatim to llama-index's ElasticsearchStore.

    Credentials are optional: depending on the deployment you supply es_url,
    or es_cloud_id plus an API key or user/password pair (see the
    ElasticsearchStore documentation for valid combinations).
    """

    index_name: str = Field(default="metagpt", description="Name of the Elasticsearch index.")
    # These may legitimately be absent, so annotate them as nullable; the
    # original `str` annotation with default=None was type-incorrect and
    # would reject an explicitly passed None at validation time.
    es_url: Union[str, None] = Field(default=None, description="Elasticsearch URL.")
    es_cloud_id: Union[str, None] = Field(default=None, description="Elasticsearch cloud ID.")
    es_api_key: Union[str, None] = Field(default=None, description="Elasticsearch API key.")
    es_user: Union[str, None] = Field(default=None, description="Elasticsearch username.")
    es_password: Union[str, None] = Field(default=None, description="Elasticsearch password.")
    batch_size: int = Field(default=200, description="Batch size for bulk indexing.")
    distance_strategy: str = Field(default="COSINE", description="Distance strategy to use for similarity search.")
class ElasticsearchRetrieverConfig(IndexRetrieverConfig):
    """Config for Elasticsearch-based retrievers."""

    # Connection/index settings forwarded to the ElasticsearchStore.
    store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.")
    # Query mode for the underlying vector store; DEFAULT presumably means
    # plain vector similarity search — confirm against llama-index docs.
    vector_store_query_mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT
class BaseRankerConfig(BaseModel):
"""Common config for rankers.
@ -53,7 +72,6 @@ class BaseRankerConfig(BaseModel):
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
top_n: int = Field(default=5, description="The number of top results to return.")
@ -72,6 +90,7 @@ class BaseIndexConfig(BaseModel):
If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.index.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
persist_path: Union[str, Path] = Field(description="The directory of saved data.")
@ -97,6 +116,13 @@ class BM25IndexConfig(BaseIndexConfig):
_no_embedding: bool = PrivateAttr(default=True)
class ElasticsearchIndexConfig(VectorIndexConfig):
    """Config for es-based index."""

    # Connection/index settings forwarded to the ElasticsearchStore.
    store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.")
    # The index lives inside Elasticsearch itself, so no local persist
    # directory is needed; override the inherited field with an empty default.
    persist_path: Union[str, Path] = ""
class ObjectNodeMetadata(BaseModel):
"""Metadata of ObjectNode."""

View file

@ -17,6 +17,7 @@ llama-index-llms-azure-openai==0.1.4
llama-index-readers-file==0.1.4
llama-index-retrievers-bm25==0.1.3
llama-index-vector-stores-faiss==0.1.1
llama-index-vector-stores-elasticsearch==0.1.5
chromadb==0.4.23
loguru==0.6.0
meilisearch==0.21.0
@ -76,3 +77,5 @@ Pillow
imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py
qianfan==0.3.2
dashscope==1.14.1
rank-bm25==0.2.2 # for tool recommendation
jieba==0.42.1 # for tool recommendation