mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-27 14:25:20 +02:00
rag add es
This commit is contained in:
parent
8c218a1e55
commit
191a86f93e
7 changed files with 157 additions and 31 deletions
|
|
@ -4,6 +4,8 @@ import chromadb
|
|||
from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.rag.factories.base import ConfigBasedFactory
|
||||
|
|
@ -11,6 +13,7 @@ from metagpt.rag.schema import (
|
|||
BaseIndexConfig,
|
||||
BM25IndexConfig,
|
||||
ChromaIndexConfig,
|
||||
ElasticsearchIndexConfig,
|
||||
FAISSIndexConfig,
|
||||
)
|
||||
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
|
||||
|
|
@ -22,6 +25,7 @@ class RAGIndexFactory(ConfigBasedFactory):
|
|||
FAISSIndexConfig: self._create_faiss,
|
||||
ChromaIndexConfig: self._create_chroma,
|
||||
BM25IndexConfig: self._create_bm25,
|
||||
ElasticsearchIndexConfig: self._create_es,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
|
|
@ -30,31 +34,44 @@ class RAGIndexFactory(ConfigBasedFactory):
|
|||
return super().get_instance(config, **kwargs)
|
||||
|
||||
def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex:
    """Load a persisted FAISS-backed index.

    Args:
        config: FAISS index config; ``persist_path`` is the directory read by both the
            FAISS vector store and the storage context.
        **kwargs: Forwarded to ``_index_from_storage`` (may carry ``embed_model``).

    Returns:
        The index loaded from the storage context.
    """
    vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path))
    storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=config.persist_path)

    return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)

def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex:
    """Load a persisted index for BM25 use.

    The storage context is built from ``config.persist_path`` only; no vector store is attached.

    Args:
        config: BM25 index config.
        **kwargs: Forwarded to ``_index_from_storage`` (may carry ``embed_model``).

    Returns:
        The index loaded from the storage context.
    """
    storage_context = StorageContext.from_defaults(persist_dir=config.persist_path)

    return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)

def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
    """Build an index on top of a persistent Chroma collection.

    Args:
        config: Chroma index config; ``persist_path`` locates the Chroma client data and
            ``collection_name`` selects the collection (created if it does not exist).
        **kwargs: Forwarded to ``_index_from_vector_store`` (may carry ``embed_model``).

    Returns:
        An index wrapping the Chroma vector store.
    """
    db = chromadb.PersistentClient(str(config.persist_path))
    chroma_collection = db.get_or_create_collection(config.collection_name)
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)

    return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)

def _create_es(self, config: ElasticsearchIndexConfig, **kwargs) -> VectorStoreIndex:
    """Build an index on top of an Elasticsearch vector store.

    Args:
        config: Elasticsearch index config; ``store_config`` is dumped to a dict and passed
            as keyword arguments to ``ElasticsearchStore``.
        **kwargs: Forwarded to ``_index_from_vector_store`` (may carry ``embed_model``).

    Returns:
        An index wrapping the Elasticsearch vector store.
    """
    vector_store = ElasticsearchStore(**config.store_config.model_dump())

    return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)

def _index_from_storage(
    self, storage_context: StorageContext, config: BaseIndexConfig, **kwargs
) -> VectorStoreIndex:
    """Load an index from an already-built storage context.

    The caller's ``storage_context`` is used as given. It must not be rebuilt from
    ``config.persist_path`` here, because that would discard any vector store the
    caller attached (as ``_create_faiss`` does).

    Args:
        storage_context: Storage context to load the index from.
        config: Index config consulted for ``embed_model``.
        **kwargs: Alternative source for ``embed_model``.

    Returns:
        The loaded index.
    """
    embed_model = self._extract_embed_model(config, **kwargs)

    return load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
|
||||
|
||||
def _index_from_vector_store(
    self, vector_store: BasePydanticVectorStore, config: BaseIndexConfig, **kwargs
) -> VectorStoreIndex:
    """Wrap an existing vector store in a ``VectorStoreIndex``.

    Args:
        vector_store: The vector store the index reads from.
        config: Index config consulted for ``embed_model``.
        **kwargs: Alternative source for ``embed_model``.

    Returns:
        The index created by ``VectorStoreIndex.from_vector_store``.
    """
    return VectorStoreIndex.from_vector_store(
        vector_store=vector_store,
        embed_model=self._extract_embed_model(config, **kwargs),
    )
|
||||
|
||||
def _extract_embed_model(self, config, **kwargs) -> BaseEmbedding:
|
||||
return self._val_from_config_or_kwargs("embed_model", config, **kwargs)
|
||||
|
|
|
|||
|
|
@ -6,18 +6,21 @@ import chromadb
|
|||
import faiss
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.rag.factories.base import ConfigBasedFactory
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
|
||||
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
BaseRetrieverConfig,
|
||||
BM25RetrieverConfig,
|
||||
ChromaRetrieverConfig,
|
||||
ElasticsearchRetrieverConfig,
|
||||
FAISSRetrieverConfig,
|
||||
IndexRetrieverConfig,
|
||||
)
|
||||
|
|
@ -32,6 +35,7 @@ class RetrieverFactory(ConfigBasedFactory):
|
|||
FAISSRetrieverConfig: self._create_faiss_retriever,
|
||||
BM25RetrieverConfig: self._create_bm25_retriever,
|
||||
ChromaRetrieverConfig: self._create_chroma_retriever,
|
||||
ElasticsearchRetrieverConfig: self._create_es_retriever,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
|
|
@ -53,20 +57,29 @@ class RetrieverFactory(ConfigBasedFactory):
|
|||
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
    """Create a FAISS retriever backed by a fresh flat L2 index.

    Args:
        config: FAISS retriever config; ``dimensions`` sizes the FAISS index. Its ``index``
            attribute is overwritten with the newly built index.
        **kwargs: Forwarded to ``_build_index_from_vector_store``.

    Returns:
        A ``FAISSRetriever`` constructed from the dumped config.
    """
    flat_l2 = faiss.IndexFlatL2(config.dimensions)
    store = FaissVectorStore(faiss_index=flat_l2)
    config.index = self._build_index_from_vector_store(config, store, **kwargs)

    retriever_kwargs = config.model_dump()
    return FAISSRetriever(**retriever_kwargs)
|
||||
|
||||
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
    """Create a BM25 retriever over the nodes of a copied index.

    Args:
        config: BM25 retriever config. Its ``index`` attribute is overwritten with a deep
            copy of the index found in ``config`` or ``kwargs``.
        **kwargs: Alternative source for ``index``.

    Returns:
        A ``DynamicBM25Retriever`` seeded with every node in the copied index's docstore.
    """
    # Deep-copy so the retriever holds its own index object rather than the caller's.
    config.index = copy.deepcopy(self._extract_index(config, **kwargs))

    return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump())
|
||||
|
||||
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
    """Create a retriever backed by a persistent Chroma collection.

    Args:
        config: Chroma retriever config; ``persist_path`` locates the Chroma client data and
            ``collection_name`` selects the collection (created if it does not exist).
            Its ``index`` attribute is overwritten with the newly built index.
        **kwargs: Forwarded to ``_build_index_from_vector_store``.

    Returns:
        A ``ChromaRetriever`` constructed from the dumped config.
    """
    client = chromadb.PersistentClient(path=str(config.persist_path))
    collection = client.get_or_create_collection(config.collection_name)
    store = ChromaVectorStore(chroma_collection=collection)

    config.index = self._build_index_from_vector_store(config, store, **kwargs)
    retriever_kwargs = config.model_dump()
    return ChromaRetriever(**retriever_kwargs)
|
||||
|
||||
def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever:
    """Create a retriever backed by an Elasticsearch vector store.

    Args:
        config: Elasticsearch retriever config; ``store_config`` is dumped to a dict and
            passed as keyword arguments to ``ElasticsearchStore``. Its ``index`` attribute
            is overwritten with the newly built index.
        **kwargs: Forwarded to ``_build_index_from_vector_store``.

    Returns:
        An ``ElasticsearchRetriever`` constructed from the dumped config.
    """
    store_kwargs = config.store_config.model_dump()
    store = ElasticsearchStore(**store_kwargs)
    config.index = self._build_index_from_vector_store(config, store, **kwargs)

    retriever_kwargs = config.model_dump()
    return ElasticsearchRetriever(**retriever_kwargs)
|
||||
|
||||
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
|
||||
return self._val_from_config_or_kwargs("index", config, **kwargs)
|
||||
|
||||
|
|
|
|||
17
metagpt/rag/retrievers/es_retriever.py
Normal file
17
metagpt/rag/retrievers/es_retriever.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
"""Elasticsearch retriever."""
|
||||
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
from llama_index.core.schema import BaseNode
|
||||
|
||||
|
||||
class ElasticsearchRetriever(VectorIndexRetriever):
    """Vector index retriever used with an Elasticsearch-backed index."""

    def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
        """Insert ``nodes`` into the underlying index.

        Extra keyword arguments are forwarded to the index's ``insert_nodes``.
        """
        index = self._index
        index.insert_nodes(nodes, **kwargs)

    def persist(self, persist_dir: str, **kwargs) -> None:
        """Accept a persist request without doing anything.

        Elasticsearch saves data automatically, so ``persist_dir`` and ``kwargs`` are ignored.
        """
        return None
|
||||
|
|
@ -8,7 +8,7 @@ class FAISSRetriever(VectorIndexRetriever):
|
|||
"""FAISS retriever."""
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
    """Insert ``nodes`` into the underlying index.

    Extra keyword arguments are forwarded to the index's ``insert_nodes``.
    """
    self._index.insert_nodes(nodes, **kwargs)
|
||||
|
||||
def persist(self, persist_dir: str, **kwargs) -> None:
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from typing import Any, Union
|
|||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.core.vector_stores.types import VectorStoreQueryMode
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
|
||||
from metagpt.rag.interface import RAGObject
|
||||
|
|
@ -46,6 +47,24 @@ class ChromaRetrieverConfig(IndexRetrieverConfig):
|
|||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
|
||||
|
||||
class ElasticsearchStoreConfig(BaseModel):
    """Options for ``ElasticsearchStore``.

    The factories dump this model to a dict and pass it as keyword arguments to
    ``ElasticsearchStore``, so field names must match that constructor's parameters.

    The connection fields default to ``None`` and are annotated as nullable, so passing
    ``None`` explicitly, or re-validating a ``model_dump()`` of this model, is accepted.
    """

    index_name: str = Field(default="metagpt", description="Name of the Elasticsearch index.")
    es_url: Union[str, None] = Field(default=None, description="Elasticsearch URL.")
    es_cloud_id: Union[str, None] = Field(default=None, description="Elasticsearch cloud ID.")
    es_api_key: Union[str, None] = Field(default=None, description="Elasticsearch API key.")
    es_user: Union[str, None] = Field(default=None, description="Elasticsearch username.")
    es_password: Union[str, None] = Field(default=None, description="Elasticsearch password.")
    batch_size: int = Field(default=200, description="Batch size for bulk indexing.")
    distance_strategy: str = Field(default="COSINE", description="Distance strategy to use for similarity search.")
|
||||
|
||||
|
||||
class ElasticsearchRetrieverConfig(IndexRetrieverConfig):
    """Retriever config for an Elasticsearch-backed index."""

    # Required: settings dumped and passed to ElasticsearchStore by the retriever factory.
    store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.")
    # Query mode for the vector store; uses the library's DEFAULT mode unless overridden.
    vector_store_query_mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT
|
||||
|
||||
|
||||
class BaseRankerConfig(BaseModel):
|
||||
"""Common config for rankers.
|
||||
|
||||
|
|
@ -53,7 +72,6 @@ class BaseRankerConfig(BaseModel):
|
|||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
top_n: int = Field(default=5, description="The number of top results to return.")
|
||||
|
||||
|
||||
|
|
@ -72,6 +90,7 @@ class BaseIndexConfig(BaseModel):
|
|||
If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.index.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
persist_path: Union[str, Path] = Field(description="The directory of saved data.")
|
||||
|
||||
|
||||
|
|
@ -97,6 +116,13 @@ class BM25IndexConfig(BaseIndexConfig):
|
|||
_no_embedding: bool = PrivateAttr(default=True)
|
||||
|
||||
|
||||
class ElasticsearchIndexConfig(VectorIndexConfig):
    """Index config for an Elasticsearch-backed index."""

    # Required: settings dumped and passed to ElasticsearchStore by the index factory.
    store_config: ElasticsearchStoreConfig = Field(..., description="ElasticsearchStore config.")
    # Defaults to an empty path; the Elasticsearch index factory does not read it.
    persist_path: Union[str, Path] = ""
|
||||
|
||||
|
||||
class ObjectNodeMetadata(BaseModel):
|
||||
"""Metadata of ObjectNode."""
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue