Merge pull request #1457 from Jacksonxhx/milvus

Integrated Milvus with MetaGPT
This commit is contained in:
Alexander Wu 2024-10-15 19:40:20 +08:00 committed by GitHub
commit 32d416bac9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 261 additions and 4 deletions

View file

@ -0,0 +1,48 @@
import random
import pytest
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore
seed_value = 42
random.seed(seed_value)
vectors = [[random.random() for _ in range(8)] for _ in range(10)]
ids = [f"doc_{i}" for i in range(10)]
metadata = [{"color": "red", "rand_number": i % 10} for i in range(10)]
def assert_almost_equal(actual, expected):
delta = 1e-10
if isinstance(expected, list):
assert len(actual) == len(expected)
for ac, exp in zip(actual, expected):
assert abs(ac - exp) <= delta, f"{ac} is not within {delta} of {exp}"
else:
assert abs(actual - expected) <= delta, f"{actual} is not within {delta} of {expected}"
@pytest.mark.skip() # Skip because the pymilvus dependency is not installed by default
def test_milvus_store():
milvus_connection = MilvusConnection(uri="./milvus_local.db")
milvus_store = MilvusStore(milvus_connection)
collection_name = "TestCollection"
milvus_store.create_collection(collection_name, dim=8)
milvus_store.add(collection_name, ids, vectors, metadata)
search_results = milvus_store.search(collection_name, query=[1.0] * 8)
assert len(search_results) > 0
first_result = search_results[0]
assert first_result["id"] == "doc_0"
search_results_with_filter = milvus_store.search(collection_name, query=[1.0] * 8, filter={"rand_number": 1})
assert len(search_results_with_filter) > 0
assert search_results_with_filter[0]["id"] == "doc_1"
milvus_store.delete(collection_name, _ids=["doc_0"])
deleted_results = milvus_store.search(collection_name, query=[1.0] * 8, limit=1)
assert deleted_results[0]["id"] != "doc_0"
milvus_store.client.drop_collection(collection_name)

View file

@ -7,7 +7,7 @@ from metagpt.rag.schema import (
ChromaIndexConfig,
ElasticsearchIndexConfig,
ElasticsearchStoreConfig,
FAISSIndexConfig,
FAISSIndexConfig, MilvusIndexConfig,
)
@ -20,6 +20,10 @@ class TestRAGIndexFactory:
def faiss_config(self):
return FAISSIndexConfig(persist_path="")
@pytest.fixture
def milvus_config(self):
return MilvusIndexConfig(uri="", collection_name="")
@pytest.fixture
def chroma_config(self):
return ChromaIndexConfig(persist_path="", collection_name="")
@ -65,6 +69,16 @@ class TestRAGIndexFactory:
):
self.index_factory.get_index(bm25_config, embed_model=mock_embedding)
def test_create_milvus_index(self, mocker, milvus_config, mock_from_vector_store, mock_embedding):
# Mock
mock_milvus_store = mocker.patch("metagpt.rag.factories.index.MilvusVectorStore")
# Exec
self.index_factory.get_index(milvus_config, embed_model=mock_embedding)
# Assert
mock_milvus_store.assert_called_once()
def test_create_chroma_index(self, mocker, chroma_config, mock_from_vector_store, mock_embedding):
# Mock
mock_chroma_db = mocker.patch("metagpt.rag.factories.index.chromadb.PersistentClient")

View file

@ -5,6 +5,7 @@ from llama_index.core.embeddings import MockEmbedding
from llama_index.core.schema import TextNode
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.milvus import MilvusVectorStore
from metagpt.rag.factories.retriever import RetrieverFactory
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
@ -12,12 +13,14 @@ from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.retrievers.milvus_retriever import MilvusRetriever
from metagpt.rag.schema import (
BM25RetrieverConfig,
ChromaRetrieverConfig,
ElasticsearchRetrieverConfig,
ElasticsearchStoreConfig,
FAISSRetrieverConfig,
MilvusRetrieverConfig,
)
@ -41,6 +44,10 @@ class TestRetrieverFactory:
def mock_chroma_vector_store(self, mocker):
return mocker.MagicMock(spec=ChromaVectorStore)
@pytest.fixture
def mock_milvus_vector_store(self, mocker):
return mocker.MagicMock(spec=MilvusVectorStore)
@pytest.fixture
def mock_es_vector_store(self, mocker):
return mocker.MagicMock(spec=ElasticsearchStore)
@ -91,6 +98,14 @@ class TestRetrieverFactory:
assert isinstance(retriever, ChromaRetriever)
def test_get_retriever_with_milvus_config(self, mocker, mock_milvus_vector_store, mock_embedding):
mock_config = MilvusRetrieverConfig(uri="/path/to/milvus.db", collection_name="test_collection")
mocker.patch("metagpt.rag.factories.retriever.MilvusVectorStore", return_value=mock_milvus_vector_store)
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
assert isinstance(retriever, MilvusRetriever)
def test_get_retriever_with_es_config(self, mocker, mock_es_vector_store, mock_embedding):
mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)