mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-05 13:52:38 +02:00
lazy dependency for milvus
Signed-off-by: ChengZi <chen.zhang@zilliz.com>
This commit is contained in:
parent
e5f037f86d
commit
4d92fdcec9
2 changed files with 28 additions and 53 deletions
|
|
@ -1,9 +1,9 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pymilvus import MilvusClient, DataType
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from metagpt.document_store.base_store import BaseStore
|
||||
|
||||
|
||||
@dataclass
|
||||
class MilvusConnection:
|
||||
"""
|
||||
|
|
@ -18,19 +18,17 @@ class MilvusConnection:
|
|||
|
||||
class MilvusStore(BaseStore):
|
||||
def __init__(self, connect: MilvusConnection):
|
||||
try:
|
||||
from pymilvus import MilvusClient
|
||||
except ImportError:
|
||||
raise Exception("Please install pymilvus first.")
|
||||
if not connect.uri:
|
||||
raise Exception("please check MilvusConnection, uri must be set.")
|
||||
self.client = MilvusClient(
|
||||
uri=connect.uri,
|
||||
token=connect.token
|
||||
)
|
||||
self.client = MilvusClient(uri=connect.uri, token=connect.token)
|
||||
|
||||
def create_collection(self, collection_name: str, dim: int, enable_dynamic_schema: bool = True):
|
||||
from pymilvus import DataType
|
||||
|
||||
def create_collection(
|
||||
self,
|
||||
collection_name: str,
|
||||
dim: int,
|
||||
enable_dynamic_schema: bool = True
|
||||
):
|
||||
if self.client.has_collection(collection_name=collection_name):
|
||||
self.client.drop_collection(collection_name=collection_name)
|
||||
|
||||
|
|
@ -42,17 +40,13 @@ class MilvusStore(BaseStore):
|
|||
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dim)
|
||||
|
||||
index_params = self.client.prepare_index_params()
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
index_type="AUTOINDEX",
|
||||
metric_type="COSINE"
|
||||
)
|
||||
index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="COSINE")
|
||||
|
||||
self.client.create_collection(
|
||||
collection_name=collection_name,
|
||||
schema=schema,
|
||||
index_params=index_params,
|
||||
enable_dynamic_schema=enable_dynamic_schema
|
||||
enable_dynamic_schema=enable_dynamic_schema,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -61,9 +55,9 @@ class MilvusStore(BaseStore):
|
|||
filter_expression = f'{key} == "{value}"'
|
||||
else:
|
||||
if isinstance(value, list):
|
||||
filter_expression = f'{key} in {value}'
|
||||
filter_expression = f"{key} in {value}"
|
||||
else:
|
||||
filter_expression = f'{key} == {value}'
|
||||
filter_expression = f"{key} == {value}"
|
||||
|
||||
return filter_expression
|
||||
|
||||
|
|
@ -71,14 +65,11 @@ class MilvusStore(BaseStore):
|
|||
self,
|
||||
collection_name: str,
|
||||
query: List[float],
|
||||
filter: Dict[str, str | int | list[int]] = None,
|
||||
filter: Dict = None,
|
||||
limit: int = 10,
|
||||
output_fields: Optional[List[str]] = None,
|
||||
) -> List[dict]:
|
||||
filter_expression = ''
|
||||
|
||||
for key, value in filter.items():
|
||||
filter_expression += f'{self.build_filter(key, value)} and '
|
||||
filter_expression = " and ".join([self.build_filter(key, value) for key, value in filter.items()])
|
||||
print(filter_expression)
|
||||
|
||||
res = self.client.search(
|
||||
|
|
@ -91,34 +82,18 @@ class MilvusStore(BaseStore):
|
|||
|
||||
return res
|
||||
|
||||
def add(
|
||||
self,
|
||||
collection_name: str,
|
||||
_ids: List[str],
|
||||
vector: List[List[float]],
|
||||
metadata: List[Dict[str, Any]]
|
||||
):
|
||||
def add(self, collection_name: str, _ids: List[str], vector: List[List[float]], metadata: List[Dict[str, Any]]):
|
||||
data = dict()
|
||||
|
||||
for i, id in enumerate(_ids):
|
||||
data['id'] = id
|
||||
data['vector'] = vector[i]
|
||||
data['metadata'] = metadata[i]
|
||||
data["id"] = id
|
||||
data["vector"] = vector[i]
|
||||
data["metadata"] = metadata[i]
|
||||
|
||||
self.client.upsert(
|
||||
collection_name=collection_name,
|
||||
data=data
|
||||
)
|
||||
self.client.upsert(collection_name=collection_name, data=data)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
_ids: List[str]
|
||||
):
|
||||
self.client.delete(
|
||||
collection_name=collection_name,
|
||||
ids=_ids
|
||||
)
|
||||
def delete(self, collection_name: str, _ids: List[str]):
|
||||
self.client.delete(collection_name=collection_name, ids=_ids)
|
||||
|
||||
def write(self, *args, **kwargs):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore
|
||||
|
||||
seed_value = 42
|
||||
|
|
@ -19,6 +22,7 @@ def assert_almost_equal(actual, expected):
|
|||
assert abs(actual - expected) <= delta, f"{actual} is not within {delta} of {expected}"
|
||||
|
||||
|
||||
@pytest.mark.skip() # Skip because the pymilvus dependency is not installed by default
|
||||
def test_milvus_store():
|
||||
milvus_connection = MilvusConnection(uri="./milvus_local.db")
|
||||
milvus_store = MilvusStore(milvus_connection)
|
||||
|
|
@ -33,11 +37,7 @@ def test_milvus_store():
|
|||
first_result = search_results[0]
|
||||
assert first_result["id"] == "doc_0"
|
||||
|
||||
search_results_with_filter = milvus_store.search(
|
||||
collection_name,
|
||||
query=[1.0] * 8,
|
||||
filter={"rand_number": 1}
|
||||
)
|
||||
search_results_with_filter = milvus_store.search(collection_name, query=[1.0] * 8, filter={"rand_number": 1})
|
||||
assert len(search_results_with_filter) > 0
|
||||
assert search_results_with_filter[0]["id"] == "doc_1"
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue