Add a `dimensions` setting to MilvusRetrieverConfig and pass it to MilvusVectorStore as `dim`.

When `dimensions` is left at 0 it is resolved from `config.embedding.dimensions`, then from a
per-provider table, and otherwise falls back to 1536 with a warning.

Review notes:
  * The unused `validator` import from pydantic was dropped (only `model_validator` is used).
  * The `index` lines below still carry the blob ids of the patch as originally generated.

diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py
index c3d3a4f80..3342b8905 100644
--- a/metagpt/rag/factories/retriever.py
+++ b/metagpt/rag/factories/retriever.py
@@ -28,7 +28,8 @@ from metagpt.rag.schema import (
     ChromaRetrieverConfig,
     ElasticsearchKeywordRetrieverConfig,
     ElasticsearchRetrieverConfig,
-    FAISSRetrieverConfig, MilvusRetrieverConfig,
+    FAISSRetrieverConfig,
+    MilvusRetrieverConfig,
 )
 
 
@@ -138,7 +139,9 @@ class RetrieverFactory(ConfigBasedFactory):
 
     @get_or_build_index
     def _build_milvus_index(self, config: MilvusRetrieverConfig, **kwargs) -> VectorStoreIndex:
-        vector_store = MilvusVectorStore(uri=config.uri, collection_name=config.collection_name, token=config.token)
+        vector_store = MilvusVectorStore(
+            uri=config.uri, collection_name=config.collection_name, token=config.token, dim=config.dimensions
+        )
 
         return self._build_index_from_vector_store(config, vector_store, **kwargs)
 
diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py
index 89e189235..e4d97068d 100644
--- a/metagpt/rag/schema.py
+++ b/metagpt/rag/schema.py
@@ -71,6 +71,27 @@ class MilvusRetrieverConfig(IndexRetrieverConfig):
     metadata: Optional[CollectionMetadata] = Field(
         default=None, description="Optional metadata to associate with the collection"
     )
+    dimensions: int = Field(default=0, description="Dimensionality of the vectors for Milvus index construction.")
+
+    # Fallback vector sizes per embedding provider, used when `config.embedding.dimensions` is unset.
+    _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
+        EmbeddingType.GEMINI: 768,
+        EmbeddingType.OLLAMA: 4096,
+    }
+
+    @model_validator(mode="after")
+    def check_dimensions(self):
+        """Resolve `dimensions` when left at 0: global embedding config, then provider table, then 1536."""
+        if self.dimensions == 0:
+            self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get(
+                config.embedding.api_type, 1536
+            )
+            if not config.embedding.dimensions and config.embedding.api_type not in self._embedding_type_to_dimensions:
+                logger.warning(
+                    f"You didn't set dimensions in config when using {config.embedding.api_type}, default to 1536"
+                )
+
+        return self
 
 
 class ChromaRetrieverConfig(IndexRetrieverConfig):
diff --git a/tests/metagpt/rag/factories/test_index.py b/tests/metagpt/rag/factories/test_index.py
index 5d8711f9f..9861e1242 100644
--- a/tests/metagpt/rag/factories/test_index.py
+++ b/tests/metagpt/rag/factories/test_index.py
@@ -69,7 +69,6 @@ class TestRAGIndexFactory:
         ):
             self.index_factory.get_index(bm25_config, embed_model=mock_embedding)
 
-
     def test_create_milvus_index(self, mocker, milvus_config, mock_from_vector_store, mock_embedding):
         # Mock
         mock_milvus_store = mocker.patch("metagpt.rag.factories.index.MilvusVectorStore")
diff --git a/tests/metagpt/rag/factories/test_retriever.py b/tests/metagpt/rag/factories/test_retriever.py
index 149e4b172..b808de26e 100644
--- a/tests/metagpt/rag/factories/test_retriever.py
+++ b/tests/metagpt/rag/factories/test_retriever.py
@@ -99,7 +99,7 @@ class TestRetrieverFactory:
         assert isinstance(retriever, ChromaRetriever)
 
     def test_get_retriever_with_milvus_config(self, mocker, mock_milvus_vector_store, mock_embedding):
-        mock_config = MilvusRetrieverConfig(uri="/path/to/milvus", collection_name="test_collection")
+        mock_config = MilvusRetrieverConfig(uri="/path/to/milvus.db", collection_name="test_collection")
         mocker.patch("metagpt.rag.factories.retriever.MilvusVectorStore", return_value=mock_milvus_vector_store)
 
         retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)