From 4d92fdcec97f3f063d46e14e8e9a1329695fb3ad Mon Sep 17 00:00:00 2001 From: ChengZi Date: Wed, 25 Sep 2024 11:47:28 +0800 Subject: [PATCH] lazy dependency for milvus Signed-off-by: ChengZi --- metagpt/document_store/milvus_store.py | 71 ++++++------------- .../document_store/test_milvus_store.py | 10 +-- 2 files changed, 28 insertions(+), 53 deletions(-) diff --git a/metagpt/document_store/milvus_store.py b/metagpt/document_store/milvus_store.py index 9d5de93cd..e4d6d985e 100644 --- a/metagpt/document_store/milvus_store.py +++ b/metagpt/document_store/milvus_store.py @@ -1,9 +1,9 @@ from dataclasses import dataclass -from typing import List, Dict, Any, Optional -from pymilvus import MilvusClient, DataType +from typing import Any, Dict, List, Optional from metagpt.document_store.base_store import BaseStore + @dataclass class MilvusConnection: """ @@ -18,19 +18,17 @@ class MilvusConnection: class MilvusStore(BaseStore): def __init__(self, connect: MilvusConnection): + try: + from pymilvus import MilvusClient + except ImportError: + raise Exception("Please install pymilvus first.") if not connect.uri: raise Exception("please check MilvusConnection, uri must be set.") - self.client = MilvusClient( - uri=connect.uri, - token=connect.token - ) + self.client = MilvusClient(uri=connect.uri, token=connect.token) + + def create_collection(self, collection_name: str, dim: int, enable_dynamic_schema: bool = True): + from pymilvus import DataType - def create_collection( - self, - collection_name: str, - dim: int, - enable_dynamic_schema: bool = True - ): if self.client.has_collection(collection_name=collection_name): self.client.drop_collection(collection_name=collection_name) @@ -42,17 +40,13 @@ class MilvusStore(BaseStore): schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dim) index_params = self.client.prepare_index_params() - index_params.add_index( - field_name="vector", - index_type="AUTOINDEX", - metric_type="COSINE" - ) + index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="COSINE") self.client.create_collection( collection_name=collection_name, schema=schema, index_params=index_params, - enable_dynamic_schema=enable_dynamic_schema + enable_dynamic_schema=enable_dynamic_schema, ) @staticmethod @@ -61,9 +55,9 @@ class MilvusStore(BaseStore): filter_expression = f'{key} == "{value}"' else: if isinstance(value, list): - filter_expression = f'{key} in {value}' + filter_expression = f"{key} in {value}" else: - filter_expression = f'{key} == {value}' + filter_expression = f"{key} == {value}" return filter_expression @@ -71,14 +65,11 @@ class MilvusStore(BaseStore): self, collection_name: str, query: List[float], - filter: Dict[str, str | int | list[int]] = None, + filter: Dict = None, limit: int = 10, output_fields: Optional[List[str]] = None, ) -> List[dict]: - filter_expression = '' - - for key, value in filter.items(): - filter_expression += f'{self.build_filter(key, value)} and ' + filter_expression = " and ".join([self.build_filter(key, value) for key, value in filter.items()]) print(filter_expression) res = self.client.search( @@ -91,34 +82,18 @@ class MilvusStore(BaseStore): return res - def add( - self, - collection_name: str, - _ids: List[str], - vector: List[List[float]], - metadata: List[Dict[str, Any]] - ): + def add(self, collection_name: str, _ids: List[str], vector: List[List[float]], metadata: List[Dict[str, Any]]): data = dict() for i, id in enumerate(_ids): - data['id'] = id - data['vector'] = vector[i] - data['metadata'] = metadata[i] + data["id"] = id + data["vector"] = vector[i] + data["metadata"] = metadata[i] - self.client.upsert( - collection_name=collection_name, - data=data - ) + self.client.upsert(collection_name=collection_name, data=data) - def delete( - self, - collection_name: str, - _ids: List[str] - ): - self.client.delete( - collection_name=collection_name, - ids=_ids - ) + def delete(self, collection_name: str, _ids: List[str]): + self.client.delete(collection_name=collection_name, ids=_ids) def write(self, *args, **kwargs): pass diff --git a/tests/metagpt/document_store/test_milvus_store.py b/tests/metagpt/document_store/test_milvus_store.py index 7cfd31381..93d4187f9 100644 --- a/tests/metagpt/document_store/test_milvus_store.py +++ b/tests/metagpt/document_store/test_milvus_store.py @@ -1,4 +1,7 @@ import random + +import pytest + from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore seed_value = 42 @@ -19,6 +22,7 @@ def assert_almost_equal(actual, expected): assert abs(actual - expected) <= delta, f"{actual} is not within {delta} of {expected}" +@pytest.mark.skip() # Skip because the pymilvus dependency is not installed by default def test_milvus_store(): milvus_connection = MilvusConnection(uri="./milvus_local.db") milvus_store = MilvusStore(milvus_connection) @@ -33,11 +37,7 @@ def test_milvus_store(): first_result = search_results[0] assert first_result["id"] == "doc_0" - search_results_with_filter = milvus_store.search( - collection_name, - query=[1.0] * 8, - filter={"rand_number": 1} - ) + search_results_with_filter = milvus_store.search(collection_name, query=[1.0] * 8, filter={"rand_number": 1}) assert len(search_results_with_filter) > 0 assert search_results_with_filter[0]["id"] == "doc_1"