integrate milvus

This commit is contained in:
Jacksonxhx 2024-08-14 18:16:18 +08:00
parent 490203d20f
commit 986fb784aa
6 changed files with 92 additions and 2 deletions

View file

@ -8,6 +8,7 @@ from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.vector_stores.milvus import MilvusVectorStore
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import (
@ -17,6 +18,7 @@ from metagpt.rag.schema import (
ElasticsearchIndexConfig,
ElasticsearchKeywordIndexConfig,
FAISSIndexConfig,
MilvusIndexConfig,
)
@ -28,6 +30,7 @@ class RAGIndexFactory(ConfigBasedFactory):
BM25IndexConfig: self._create_bm25,
ElasticsearchIndexConfig: self._create_es,
ElasticsearchKeywordIndexConfig: self._create_es,
MilvusIndexConfig: self._create_milvus
}
super().__init__(creators)
@ -46,6 +49,11 @@ class RAGIndexFactory(ConfigBasedFactory):
return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)
def _create_milvus(self, config: MilvusIndexConfig, **kwargs) -> VectorStoreIndex:
    """Build a vector index on top of a Milvus vector store.

    Args:
        config: Supplies the ``collection_name``, ``uri`` and ``token`` used to
            construct the Milvus vector store.
        **kwargs: Forwarded unchanged to ``_index_from_vector_store``.

    Returns:
        Whatever ``_index_from_vector_store`` produces for the Milvus store.
    """
    milvus_store = MilvusVectorStore(
        collection_name=config.collection_name,
        uri=config.uri,
        token=config.token,
    )
    return self._index_from_vector_store(vector_store=milvus_store, config=config, **kwargs)
def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
db = chromadb.PersistentClient(str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)

View file

@ -12,6 +12,7 @@ from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.vector_stores.milvus import MilvusVectorStore
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.retrievers.base import RAGRetriever
@ -20,13 +21,14 @@ from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.retrievers.milvus_retriever import MilvusRetriever
from metagpt.rag.schema import (
BaseRetrieverConfig,
BM25RetrieverConfig,
ChromaRetrieverConfig,
ElasticsearchKeywordRetrieverConfig,
ElasticsearchRetrieverConfig,
FAISSRetrieverConfig,
FAISSRetrieverConfig, MilvusRetrieverConfig,
)
@ -56,6 +58,7 @@ class RetrieverFactory(ConfigBasedFactory):
ChromaRetrieverConfig: self._create_chroma_retriever,
ElasticsearchRetrieverConfig: self._create_es_retriever,
ElasticsearchKeywordRetrieverConfig: self._create_es_retriever,
MilvusRetrieverConfig: self._create_milvus_retriever,
}
super().__init__(creators)
@ -76,6 +79,11 @@ class RetrieverFactory(ConfigBasedFactory):
return index.as_retriever()
def _create_milvus_retriever(self, config: MilvusRetrieverConfig, **kwargs) -> MilvusRetriever:
    """Create a ``MilvusRetriever`` from a Milvus retriever config.

    Args:
        config: Retriever config; its ``index`` attribute is overwritten with
            the index returned by ``_build_milvus_index``.
        **kwargs: Forwarded unchanged to ``_build_milvus_index``.

    Returns:
        A ``MilvusRetriever`` constructed from the dumped config fields.
    """
    # Attach the index to the config before dumping it; the dumped fields
    # become the retriever's keyword arguments.
    config.index = self._build_milvus_index(config, **kwargs)
    retriever_kwargs = config.model_dump()
    return MilvusRetriever(**retriever_kwargs)
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
config.index = self._build_faiss_index(config, **kwargs)
@ -128,6 +136,12 @@ class RetrieverFactory(ConfigBasedFactory):
return self._build_index_from_vector_store(config, vector_store, **kwargs)
@get_or_build_index
def _build_milvus_index(self, config: MilvusRetrieverConfig, **kwargs) -> VectorStoreIndex:
    """Build the vector index used by the Milvus retriever.

    Args:
        config: Supplies the ``uri``, ``collection_name`` and ``token`` used to
            construct the Milvus vector store.
        **kwargs: Forwarded unchanged to ``_build_index_from_vector_store``.

    Returns:
        Whatever ``_build_index_from_vector_store`` produces for the Milvus store.
    """
    store_options = {
        "uri": config.uri,
        "collection_name": config.collection_name,
        "token": config.token,
    }
    milvus_store = MilvusVectorStore(**store_options)
    return self._build_index_from_vector_store(config, milvus_store, **kwargs)
@get_or_build_index
def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
vector_store = ElasticsearchStore(**config.store_config.model_dump())

View file

@ -0,0 +1,17 @@
"""Milvus retriever."""
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode
class MilvusRetriever(VectorIndexRetriever):
    """Vector index retriever for Milvus that also supports adding nodes."""

    def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
        """Insert ``nodes`` into the retriever's index.

        Args:
            nodes: Nodes to add.
            **kwargs: Forwarded unchanged to the index's ``insert_nodes``.
        """
        index = self._index
        index.insert_nodes(nodes, **kwargs)

    def persist(self, persist_dir: str, **kwargs) -> None:
        """Accept a persist request without doing any work.

        Milvus automatically saves, so there is no need to implement.
        """
        return None

View file

@ -62,6 +62,17 @@ class BM25RetrieverConfig(IndexRetrieverConfig):
_no_embedding: bool = PrivateAttr(default=True)
class MilvusRetrieverConfig(IndexRetrieverConfig):
    """Config for Milvus-based retrievers."""

    # Passed as ``uri`` when the Milvus vector store is constructed.
    uri: str = Field(default="./milvus_local.db", description="The directory to save data.")
    # Name of the Milvus collection the vector store operates on.
    collection_name: str = Field(default="metagpt", description="The name of the collection.")
    # Declared Optional because the default is None: with a plain ``str``
    # annotation an explicitly passed ``token=None`` is rejected by validation.
    # This also matches the ``token`` declaration on MilvusIndexConfig.
    token: Optional[str] = Field(default=None, description="The token for Milvus")
    metadata: Optional[CollectionMetadata] = Field(
        default=None, description="Optional metadata to associate with the collection"
    )
class ChromaRetrieverConfig(IndexRetrieverConfig):
"""Config for Chroma-based retrievers."""
@ -169,6 +180,16 @@ class ChromaIndexConfig(VectorIndexConfig):
default=None, description="Optional metadata to associate with the collection"
)
class MilvusIndexConfig(VectorIndexConfig):
    """Config for milvus-based index."""

    # Name of the Milvus collection the vector store operates on.
    collection_name: str = Field(
        default="metagpt",
        description="The name of the collection.",
    )
    # Passed as ``uri`` when the Milvus vector store is constructed.
    uri: str = Field(
        default="./milvus_local.db",
        description="The uri of the index.",
    )
    # Optional; None when no token is supplied.
    token: Optional[str] = Field(
        default=None,
        description="The token of the index.",
    )
    metadata: Optional[CollectionMetadata] = Field(
        default=None,
        description="Optional metadata to associate with the collection",
    )
class BM25IndexConfig(BaseIndexConfig):
"""Config for bm25-based index."""

View file

@ -7,7 +7,7 @@ from metagpt.rag.schema import (
ChromaIndexConfig,
ElasticsearchIndexConfig,
ElasticsearchStoreConfig,
FAISSIndexConfig,
FAISSIndexConfig, MilvusIndexConfig,
)
@ -20,6 +20,10 @@ class TestRAGIndexFactory:
def faiss_config(self):
return FAISSIndexConfig(persist_path="")
@pytest.fixture
def milvus_config(self):
    """Provide a Milvus index config with an empty ``uri`` and ``collection_name``."""
    config = MilvusIndexConfig(uri="", collection_name="")
    return config
@pytest.fixture
def chroma_config(self):
    """Provide a Chroma index config with an empty ``persist_path`` and ``collection_name``."""
    config = ChromaIndexConfig(persist_path="", collection_name="")
    return config
@ -65,6 +69,17 @@ class TestRAGIndexFactory:
):
self.index_factory.get_index(bm25_config, embed_model=mock_embedding)
def test_create_milvus_index(self, mocker, milvus_config, mock_from_vector_store, mock_embedding):
    """``get_index`` with a Milvus config constructs the Milvus vector store exactly once."""
    # Arrange: replace the store class referenced by the index factory module.
    store_cls = mocker.patch("metagpt.rag.factories.index.MilvusVectorStore")

    # Act
    self.index_factory.get_index(milvus_config, embed_model=mock_embedding)

    # Assert
    store_cls.assert_called_once()
def test_create_chroma_index(self, mocker, chroma_config, mock_from_vector_store, mock_embedding):
# Mock
mock_chroma_db = mocker.patch("metagpt.rag.factories.index.chromadb.PersistentClient")

View file

@ -5,6 +5,7 @@ from llama_index.core.embeddings import MockEmbedding
from llama_index.core.schema import TextNode
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.milvus import MilvusVectorStore
from metagpt.rag.factories.retriever import RetrieverFactory
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
@ -12,12 +13,14 @@ from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.retrievers.milvus_retriever import MilvusRetriever
from metagpt.rag.schema import (
BM25RetrieverConfig,
ChromaRetrieverConfig,
ElasticsearchRetrieverConfig,
ElasticsearchStoreConfig,
FAISSRetrieverConfig,
MilvusRetrieverConfig,
)
@ -41,6 +44,10 @@ class TestRetrieverFactory:
def mock_chroma_vector_store(self, mocker):
return mocker.MagicMock(spec=ChromaVectorStore)
@pytest.fixture
def mock_milvus_vector_store(self, mocker):
    """Provide a ``MagicMock`` created with the ``MilvusVectorStore`` spec."""
    store_mock = mocker.MagicMock(spec=MilvusVectorStore)
    return store_mock
@pytest.fixture
def mock_es_vector_store(self, mocker):
    """Provide a ``MagicMock`` created with the ``ElasticsearchStore`` spec."""
    store_mock = mocker.MagicMock(spec=ElasticsearchStore)
    return store_mock
@ -91,6 +98,14 @@ class TestRetrieverFactory:
assert isinstance(retriever, ChromaRetriever)
def test_get_retriever_with_milvus_config(self, mocker, mock_milvus_vector_store, mock_embedding):
    """A Milvus retriever config yields a ``MilvusRetriever`` from the factory."""
    # Arrange: the factory module's store class returns the mocked vector store.
    mocker.patch("metagpt.rag.factories.retriever.MilvusVectorStore", return_value=mock_milvus_vector_store)
    milvus_config = MilvusRetrieverConfig(uri="/path/to/milvus", collection_name="test_collection")

    # Act
    result = self.retriever_factory.get_retriever(
        configs=[milvus_config],
        nodes=[],
        embed_model=mock_embedding,
    )

    # Assert
    assert isinstance(result, MilvusRetriever)
def test_get_retriever_with_es_config(self, mocker, mock_es_vector_store, mock_embedding):
mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)