diff --git a/metagpt/document_store/milvus_store.py b/metagpt/document_store/milvus_store.py
new file mode 100644
index 000000000..9d5de93cd
--- /dev/null
+++ b/metagpt/document_store/milvus_store.py
@@ -0,0 +1,151 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union
+
+from pymilvus import DataType, MilvusClient
+
+from metagpt.document_store.base_store import BaseStore
+
+
+@dataclass
+class MilvusConnection:
+    """Connection settings for a Milvus server.
+
+    Args:
+        uri: milvus url
+        token: milvus token
+    """
+
+    uri: Optional[str] = None
+    token: Optional[str] = None
+
+
+class MilvusStore(BaseStore):
+    """Document store that keeps vectors and metadata in Milvus collections."""
+
+    def __init__(self, connect: MilvusConnection):
+        """Create the underlying ``MilvusClient``.
+
+        Args:
+            connect: connection settings; ``uri`` is required.
+
+        Raises:
+            ValueError: if ``connect.uri`` is empty or unset.
+        """
+        if not connect.uri:
+            raise ValueError("please check MilvusConnection, uri must be set.")
+        self.client = MilvusClient(uri=connect.uri, token=connect.token)
+
+    def create_collection(self, collection_name: str, dim: int, enable_dynamic_schema: bool = True):
+        """Create ``collection_name`` with an ``id`` key and a ``vector`` field.
+
+        An existing collection with the same name is dropped first.
+
+        Args:
+            collection_name: name of the collection to create.
+            dim: dimension of the ``vector`` field.
+            enable_dynamic_schema: whether rows may carry fields that are not
+                declared in the schema; ``add`` writes such a field (``metadata``).
+        """
+        if self.client.has_collection(collection_name=collection_name):
+            self.client.drop_collection(collection_name=collection_name)
+
+        schema = self.client.create_schema(
+            auto_id=False,
+            # Honour the caller's flag instead of hard-coding False, which would
+            # leave ``enable_dynamic_schema`` without effect on the schema.
+            enable_dynamic_field=enable_dynamic_schema,
+        )
+        schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=36)
+        schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dim)
+
+        index_params = self.client.prepare_index_params()
+        index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="COSINE")
+
+        self.client.create_collection(
+            collection_name=collection_name,
+            schema=schema,
+            index_params=index_params,
+            enable_dynamic_schema=enable_dynamic_schema,
+        )
+
+    @staticmethod
+    def build_filter(key, value) -> str:
+        """Render one ``key``/``value`` pair as a filter expression.
+
+        Strings become a quoted equality test, lists an ``in`` test and any
+        other value a plain equality test.
+        """
+        if isinstance(value, str):
+            # Escape backslashes and double quotes so the value cannot end the
+            # string literal early and corrupt the expression.
+            escaped = value.replace("\\", "\\\\").replace('"', '\\"')
+            return f'{key} == "{escaped}"'
+        if isinstance(value, list):
+            return f"{key} in {value}"
+        return f"{key} == {value}"
+
+    def search(
+        self,
+        collection_name: str,
+        query: List[float],
+        filter: Optional[Dict[str, Union[str, int, List[int]]]] = None,
+        limit: int = 10,
+        output_fields: Optional[List[str]] = None,
+    ) -> List[dict]:
+        """Return the hits closest to ``query`` in ``collection_name``.
+
+        Args:
+            collection_name: collection to search.
+            query: query vector.
+            filter: optional ``field -> value`` conditions, combined with
+                ``and``; ``None`` or an empty dict applies no filter.
+            limit: maximum number of hits to return.
+            output_fields: fields to return with each hit.
+
+        Returns:
+            The hits for the single query vector.
+        """
+        conditions = filter or {}
+        # Join with " and " so no dangling operator is left at the end.
+        filter_expression = " and ".join(self.build_filter(key, value) for key, value in conditions.items())
+
+        return self.client.search(
+            collection_name=collection_name,
+            data=[query],
+            filter=filter_expression,
+            limit=limit,
+            output_fields=output_fields,
+        )[0]
+
+    def add(self, collection_name: str, _ids: List[str], vector: List[List[float]], metadata: List[Dict[str, Any]]):
+        """Upsert one row per id, pairing ids, vectors and metadata by position.
+
+        Args:
+            collection_name: collection to write to.
+            _ids: primary keys of the rows.
+            vector: one embedding per id.
+            metadata: one metadata dict per id, stored under ``metadata``.
+
+        Raises:
+            ValueError: if the three lists differ in length.
+        """
+        if not len(_ids) == len(vector) == len(metadata):
+            raise ValueError("_ids, vector and metadata must have the same length.")
+        if not _ids:
+            return
+
+        # Build one row per document; reusing a single dict would upsert only
+        # the last entry.
+        data = [
+            {"id": doc_id, "vector": doc_vector, "metadata": doc_metadata}
+            for doc_id, doc_vector, doc_metadata in zip(_ids, vector, metadata)
+        ]
+        self.client.upsert(collection_name=collection_name, data=data)
+
+    def delete(self, collection_name: str, _ids: List[str]):
+        """Delete the rows whose primary keys are in ``_ids``."""
+        self.client.delete(collection_name=collection_name, ids=_ids)
+
+    def write(self, *args, **kwargs):
+        """Do nothing; bulk writing is not implemented for this store."""
+        pass
diff --git a/requirements.txt b/requirements.txt
index 8bf0ee399..92f5654da 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -79,3 +79,4 @@ gymnasium==0.29.1
 boto3~=1.34.69
 spark_ai_python~=0.3.30
 agentops
+pymilvus==2.4.5
diff --git a/tests/metagpt/document_store/test_milvus_store.py b/tests/metagpt/document_store/test_milvus_store.py
new file mode 100644
index 000000000..7cfd31381
--- /dev/null
+++ b/tests/metagpt/document_store/test_milvus_store.py
@@ -0,0 +1,49 @@
+import random
+
+from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore
+
+seed_value = 42
+random.seed(seed_value)
+
+vectors = [[random.random() for _ in range(8)] for _ in range(10)]
+ids = [f"doc_{i}" for i in range(10)]
+metadata = [{"color": "red", "rand_number": i % 10} for i in range(10)]
+
+
+def assert_almost_equal(actual, expected):
+    """Assert that numbers, or equal-length lists of numbers, agree within 1e-10."""
+    delta = 1e-10
+    if isinstance(expected, list):
+        assert len(actual) == len(expected)
+        for ac, exp in zip(actual, expected):
+            assert abs(ac - exp) <= delta, f"{ac} is not within {delta} of {exp}"
+    else:
+        assert abs(actual - expected) <= delta, f"{actual} is not within {delta} of {expected}"
+
+
+def test_milvus_store():
+    """Exercise create/add/search/delete against a local Milvus database file."""
+    milvus_connection = MilvusConnection(uri="./milvus_local.db")
+    milvus_store = MilvusStore(milvus_connection)
+
+    collection_name = "TestCollection"
+    milvus_store.create_collection(collection_name, dim=8)
+
+    milvus_store.add(collection_name, ids, vectors, metadata)
+
+    search_results = milvus_store.search(collection_name, query=[1.0] * 8)
+    assert len(search_results) > 0
+    first_result = search_results[0]
+    assert first_result["id"] == "doc_0"
+
+    # NOTE(review): add() stores metadata under the "metadata" key; confirm that
+    # a top-level ``rand_number`` filter matches the stored rows.
+    search_results_with_filter = milvus_store.search(collection_name, query=[1.0] * 8, filter={"rand_number": 1})
+    assert len(search_results_with_filter) > 0
+    assert search_results_with_filter[0]["id"] == "doc_1"
+
+    milvus_store.delete(collection_name, _ids=["doc_0"])
+    deleted_results = milvus_store.search(collection_name, query=[1.0] * 8, limit=1)
+    assert deleted_results[0]["id"] != "doc_0"
+
+    milvus_store.client.drop_collection(collection_name)