lazy dependency for milvus

Signed-off-by: ChengZi <chen.zhang@zilliz.com>
This commit is contained in:
ChengZi 2024-09-25 11:47:28 +08:00
parent e5f037f86d
commit 4d92fdcec9
2 changed files with 28 additions and 53 deletions

View file

@ -1,9 +1,9 @@
from dataclasses import dataclass
from typing import List, Dict, Any, Optional
from pymilvus import MilvusClient, DataType
from typing import Any, Dict, List, Optional
from metagpt.document_store.base_store import BaseStore
@dataclass
class MilvusConnection:
"""
@ -18,19 +18,17 @@ class MilvusConnection:
class MilvusStore(BaseStore):
def __init__(self, connect: MilvusConnection):
try:
from pymilvus import MilvusClient
except ImportError:
raise Exception("Please install pymilvus first.")
if not connect.uri:
raise Exception("please check MilvusConnection, uri must be set.")
self.client = MilvusClient(
uri=connect.uri,
token=connect.token
)
self.client = MilvusClient(uri=connect.uri, token=connect.token)
def create_collection(self, collection_name: str, dim: int, enable_dynamic_schema: bool = True):
from pymilvus import DataType
def create_collection(
self,
collection_name: str,
dim: int,
enable_dynamic_schema: bool = True
):
if self.client.has_collection(collection_name=collection_name):
self.client.drop_collection(collection_name=collection_name)
@ -42,17 +40,13 @@ class MilvusStore(BaseStore):
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dim)
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector",
index_type="AUTOINDEX",
metric_type="COSINE"
)
index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="COSINE")
self.client.create_collection(
collection_name=collection_name,
schema=schema,
index_params=index_params,
enable_dynamic_schema=enable_dynamic_schema
enable_dynamic_schema=enable_dynamic_schema,
)
@staticmethod
@ -61,9 +55,9 @@ class MilvusStore(BaseStore):
filter_expression = f'{key} == "{value}"'
else:
if isinstance(value, list):
filter_expression = f'{key} in {value}'
filter_expression = f"{key} in {value}"
else:
filter_expression = f'{key} == {value}'
filter_expression = f"{key} == {value}"
return filter_expression
@ -71,14 +65,11 @@ class MilvusStore(BaseStore):
self,
collection_name: str,
query: List[float],
filter: Dict[str, str | int | list[int]] = None,
filter: Dict = None,
limit: int = 10,
output_fields: Optional[List[str]] = None,
) -> List[dict]:
filter_expression = ''
for key, value in filter.items():
filter_expression += f'{self.build_filter(key, value)} and '
filter_expression = " and ".join([self.build_filter(key, value) for key, value in filter.items()])
print(filter_expression)
res = self.client.search(
@ -91,34 +82,18 @@ class MilvusStore(BaseStore):
return res
def add(
self,
collection_name: str,
_ids: List[str],
vector: List[List[float]],
metadata: List[Dict[str, Any]]
):
def add(self, collection_name: str, _ids: List[str], vector: List[List[float]], metadata: List[Dict[str, Any]]):
data = dict()
for i, id in enumerate(_ids):
data['id'] = id
data['vector'] = vector[i]
data['metadata'] = metadata[i]
data["id"] = id
data["vector"] = vector[i]
data["metadata"] = metadata[i]
self.client.upsert(
collection_name=collection_name,
data=data
)
self.client.upsert(collection_name=collection_name, data=data)
def delete(
self,
collection_name: str,
_ids: List[str]
):
self.client.delete(
collection_name=collection_name,
ids=_ids
)
def delete(self, collection_name: str, _ids: List[str]):
self.client.delete(collection_name=collection_name, ids=_ids)
def write(self, *args, **kwargs):
pass

View file

@ -1,4 +1,7 @@
import random
import pytest
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore
seed_value = 42
@ -19,6 +22,7 @@ def assert_almost_equal(actual, expected):
assert abs(actual - expected) <= delta, f"{actual} is not within {delta} of {expected}"
@pytest.mark.skip() # Skip because the pymilvus dependency is not installed by default
def test_milvus_store():
milvus_connection = MilvusConnection(uri="./milvus_local.db")
milvus_store = MilvusStore(milvus_connection)
@ -33,11 +37,7 @@ def test_milvus_store():
first_result = search_results[0]
assert first_result["id"] == "doc_0"
search_results_with_filter = milvus_store.search(
collection_name,
query=[1.0] * 8,
filter={"rand_number": 1}
)
search_results_with_filter = milvus_store.search(collection_name, query=[1.0] * 8, filter={"rand_number": 1})
assert len(search_results_with_filter) > 0
assert search_results_with_filter[0]["id"] == "doc_1"