Merge pull request #1457 from Jacksonxhx/milvus

Integrated Milvus with MetaGPT
This commit is contained in:
Alexander Wu 2024-10-15 19:40:20 +08:00 committed by GitHub
commit 32d416bac9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 261 additions and 4 deletions

View file

@ -0,0 +1,99 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from metagpt.document_store.base_store import BaseStore
@dataclass
class MilvusConnection:
"""
Args:
uri: milvus url
token: milvus token
"""
uri: str = None
token: str = None
class MilvusStore(BaseStore):
def __init__(self, connect: MilvusConnection):
try:
from pymilvus import MilvusClient
except ImportError:
raise Exception("Please install pymilvus first.")
if not connect.uri:
raise Exception("please check MilvusConnection, uri must be set.")
self.client = MilvusClient(uri=connect.uri, token=connect.token)
def create_collection(self, collection_name: str, dim: int, enable_dynamic_schema: bool = True):
from pymilvus import DataType
if self.client.has_collection(collection_name=collection_name):
self.client.drop_collection(collection_name=collection_name)
schema = self.client.create_schema(
auto_id=False,
enable_dynamic_field=False,
)
schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=36)
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dim)
index_params = self.client.prepare_index_params()
index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="COSINE")
self.client.create_collection(
collection_name=collection_name,
schema=schema,
index_params=index_params,
enable_dynamic_schema=enable_dynamic_schema,
)
@staticmethod
def build_filter(key, value) -> str:
if isinstance(value, str):
filter_expression = f'{key} == "{value}"'
else:
if isinstance(value, list):
filter_expression = f"{key} in {value}"
else:
filter_expression = f"{key} == {value}"
return filter_expression
def search(
self,
collection_name: str,
query: List[float],
filter: Dict = None,
limit: int = 10,
output_fields: Optional[List[str]] = None,
) -> List[dict]:
filter_expression = " and ".join([self.build_filter(key, value) for key, value in filter.items()])
print(filter_expression)
res = self.client.search(
collection_name=collection_name,
data=[query],
filter=filter_expression,
limit=limit,
output_fields=output_fields,
)[0]
return res
def add(self, collection_name: str, _ids: List[str], vector: List[List[float]], metadata: List[Dict[str, Any]]):
data = dict()
for i, id in enumerate(_ids):
data["id"] = id
data["vector"] = vector[i]
data["metadata"] = metadata[i]
self.client.upsert(collection_name=collection_name, data=data)
def delete(self, collection_name: str, _ids: List[str]):
self.client.delete(collection_name=collection_name, ids=_ids)
def write(self, *args, **kwargs):
pass

View file

@ -8,6 +8,7 @@ from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.vector_stores.milvus import MilvusVectorStore
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import (
@ -17,6 +18,7 @@ from metagpt.rag.schema import (
ElasticsearchIndexConfig,
ElasticsearchKeywordIndexConfig,
FAISSIndexConfig,
MilvusIndexConfig,
)
@ -28,6 +30,7 @@ class RAGIndexFactory(ConfigBasedFactory):
BM25IndexConfig: self._create_bm25,
ElasticsearchIndexConfig: self._create_es,
ElasticsearchKeywordIndexConfig: self._create_es,
MilvusIndexConfig: self._create_milvus
}
super().__init__(creators)
@ -46,6 +49,11 @@ class RAGIndexFactory(ConfigBasedFactory):
return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)
def _create_milvus(self, config: MilvusIndexConfig, **kwargs) -> VectorStoreIndex:
vector_store = MilvusVectorStore(collection_name=config.collection_name, uri=config.uri, token=config.token)
return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)
def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
db = chromadb.PersistentClient(str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)

View file

@ -12,6 +12,7 @@ from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.vector_stores.milvus import MilvusVectorStore
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.retrievers.base import RAGRetriever
@ -20,6 +21,7 @@ from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.retrievers.milvus_retriever import MilvusRetriever
from metagpt.rag.schema import (
BaseRetrieverConfig,
BM25RetrieverConfig,
@ -27,6 +29,7 @@ from metagpt.rag.schema import (
ElasticsearchKeywordRetrieverConfig,
ElasticsearchRetrieverConfig,
FAISSRetrieverConfig,
MilvusRetrieverConfig,
)
@ -56,6 +59,7 @@ class RetrieverFactory(ConfigBasedFactory):
ChromaRetrieverConfig: self._create_chroma_retriever,
ElasticsearchRetrieverConfig: self._create_es_retriever,
ElasticsearchKeywordRetrieverConfig: self._create_es_retriever,
MilvusRetrieverConfig: self._create_milvus_retriever,
}
super().__init__(creators)
@ -76,6 +80,11 @@ class RetrieverFactory(ConfigBasedFactory):
return index.as_retriever()
def _create_milvus_retriever(self, config: MilvusRetrieverConfig, **kwargs) -> MilvusRetriever:
config.index = self._build_milvus_index(config, **kwargs)
return MilvusRetriever(**config.model_dump())
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
config.index = self._build_faiss_index(config, **kwargs)
@ -128,6 +137,12 @@ class RetrieverFactory(ConfigBasedFactory):
return self._build_index_from_vector_store(config, vector_store, **kwargs)
@get_or_build_index
def _build_milvus_index(self, config: MilvusRetrieverConfig, **kwargs) -> VectorStoreIndex:
vector_store = MilvusVectorStore(uri=config.uri, collection_name=config.collection_name, token=config.token, dim=config.dimensions)
return self._build_index_from_vector_store(config, vector_store, **kwargs)
@get_or_build_index
def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
vector_store = ElasticsearchStore(**config.store_config.model_dump())

View file

@ -0,0 +1,17 @@
"""Milvus retriever."""
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode
class MilvusRetriever(VectorIndexRetriever):
"""Milvus retriever."""
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
"""Support add nodes."""
self._index.insert_nodes(nodes, **kwargs)
def persist(self, persist_dir: str, **kwargs) -> None:
"""Support persist.
Milvus automatically saves, so there is no need to implement."""

View file

@ -8,7 +8,7 @@ from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator, validator
from metagpt.config2 import config
from metagpt.configs.embedding_config import EmbeddingType
@ -62,6 +62,36 @@ class BM25RetrieverConfig(IndexRetrieverConfig):
_no_embedding: bool = PrivateAttr(default=True)
class MilvusRetrieverConfig(IndexRetrieverConfig):
"""Config for Milvus-based retrievers."""
uri: str = Field(default="./milvus_local.db", description="The directory to save data.")
collection_name: str = Field(default="metagpt", description="The name of the collection.")
token: str = Field(default=None, description="The token for Milvus")
metadata: Optional[CollectionMetadata] = Field(
default=None, description="Optional metadata to associate with the collection"
)
dimensions: int = Field(default=0, description="Dimensionality of the vectors for Milvus index construction.")
_embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
EmbeddingType.GEMINI: 768,
EmbeddingType.OLLAMA: 4096,
}
@model_validator(mode="after")
def check_dimensions(self):
if self.dimensions == 0:
self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get(
config.embedding.api_type, 1536
)
if not config.embedding.dimensions and config.embedding.api_type not in self._embedding_type_to_dimensions:
logger.warning(
f"You didn't set dimensions in config when using {config.embedding.api_type}, default to 1536"
)
return self
class ChromaRetrieverConfig(IndexRetrieverConfig):
"""Config for Chroma-based retrievers."""
@ -169,6 +199,16 @@ class ChromaIndexConfig(VectorIndexConfig):
default=None, description="Optional metadata to associate with the collection"
)
class MilvusIndexConfig(VectorIndexConfig):
"""Config for milvus-based index."""
collection_name: str = Field(default="metagpt", description="The name of the collection.")
uri: str = Field(default="./milvus_local.db", description="The uri of the index.")
token: Optional[str] = Field(default=None, description="The token of the index.")
metadata: Optional[CollectionMetadata] = Field(
default=None, description="Optional metadata to associate with the collection"
)
class BM25IndexConfig(BaseIndexConfig):
"""Config for bm25-based index."""