integrated milvus

This commit is contained in:
Jacksonxhx 2024-08-14 16:27:30 +08:00
parent ab846f65e4
commit 490203d20f
3 changed files with 173 additions and 0 deletions

View file

@ -0,0 +1,124 @@
from dataclasses import dataclass
from typing import List, Dict, Any, Optional
from pymilvus import MilvusClient, DataType
from metagpt.document_store.base_store import BaseStore
@dataclass
class MilvusConnection:
    """Connection settings handed to ``MilvusStore``.

    Args:
        uri: milvus url
        token: milvus token
    """

    # Both fields default to None, so they are Optional rather than plain str.
    uri: Optional[str] = None
    token: Optional[str] = None
class MilvusStore(BaseStore):
    """Document store backed by a ``pymilvus.MilvusClient``.

    Records are stored with three keys: ``id`` (string primary key),
    ``vector`` (float vector) and ``metadata`` (the caller-supplied dict).
    """

    def __init__(self, connect: MilvusConnection):
        """Create the underlying Milvus client.

        Args:
            connect: connection settings; ``connect.uri`` must be set.

        Raises:
            ValueError: if ``connect.uri`` is empty or ``None``.
        """
        if not connect.uri:
            raise ValueError("please check MilvusConnection, uri must be set.")

        self.client = MilvusClient(uri=connect.uri, token=connect.token)

    def create_collection(self, collection_name: str, dim: int, enable_dynamic_schema: bool = True):
        """Create ``collection_name`` from scratch, dropping any existing one.

        The schema has a VARCHAR primary key ``id`` (max length 36) and a
        FLOAT_VECTOR field ``vector`` of dimension ``dim``, indexed with
        AUTOINDEX using the COSINE metric.

        Args:
            collection_name: name of the collection to (re)create.
            dim: dimensionality of the stored vectors.
            enable_dynamic_schema: whether fields that are not declared in the
                schema (such as ``metadata``) may be stored.
        """
        if self.client.has_collection(collection_name=collection_name):
            self.client.drop_collection(collection_name=collection_name)

        # The flag was previously hard-coded to False here, so the parameter
        # had no effect on the schema that was actually created.
        schema = self.client.create_schema(
            auto_id=False,
            enable_dynamic_field=enable_dynamic_schema,
        )
        schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=36)
        schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dim)

        index_params = self.client.prepare_index_params()
        index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="COSINE")

        # NOTE(review): the flag used to be forwarded here as
        # `enable_dynamic_schema=`, which does not appear to be a documented
        # MilvusClient option; it is now applied through the schema instead.
        # Confirm against the pinned pymilvus version.
        self.client.create_collection(
            collection_name=collection_name,
            schema=schema,
            index_params=index_params,
        )

    @staticmethod
    def build_filter(key, value) -> str:
        """Render one ``key``/``value`` pair as a Milvus filter clause.

        Strings become ``key == "value"`` (with backslashes and double quotes
        escaped so the value cannot break out of the literal), lists become
        ``key in [...]``, and anything else becomes ``key == value``.
        """
        if isinstance(value, str):
            escaped = value.replace("\\", "\\\\").replace('"', '\\"')
            return f'{key} == "{escaped}"'
        if isinstance(value, list):
            return f"{key} in {value}"
        return f"{key} == {value}"

    def search(
        self,
        collection_name: str,
        query: List[float],
        filter: Optional[Dict[str, Any]] = None,
        limit: int = 10,
        output_fields: Optional[List[str]] = None,
    ) -> List[dict]:
        """Run a vector search for a single query vector.

        Args:
            collection_name: collection to search.
            query: the query vector.
            filter: optional mapping of field name to a ``str``, ``int`` or
                list value; all clauses are combined with ``and``. ``None``
                or an empty mapping means no filtering.
            limit: maximum number of hits to return.
            output_fields: extra fields to include in each hit.

        Returns:
            The hits for ``query`` (the first element of the client's
            per-query result list).
        """
        # Joining avoids the dangling " and " that manual concatenation left
        # at the end of the expression.
        clauses = [self.build_filter(key, value) for key, value in (filter or {}).items()]
        filter_expression = " and ".join(clauses)

        results = self.client.search(
            collection_name=collection_name,
            data=[query],
            filter=filter_expression,
            limit=limit,
            output_fields=output_fields,
        )
        return results[0]

    def add(self, collection_name: str, _ids: List[str], vector: List[List[float]], metadata: List[Dict[str, Any]]):
        """Upsert one record per id in a single client call.

        Args:
            collection_name: target collection.
            _ids: primary keys, one per record.
            vector: vectors, aligned with ``_ids``.
            metadata: metadata dicts, aligned with ``_ids``.

        Raises:
            ValueError: if the three lists differ in length.
        """
        if not (len(_ids) == len(vector) == len(metadata)):
            raise ValueError("_ids, vector and metadata must have the same length.")

        # One dict per record; a single shared dict would be overwritten on
        # every iteration and only the last record would be written.
        rows = [
            {"id": record_id, "vector": record_vector, "metadata": record_metadata}
            for record_id, record_vector, record_metadata in zip(_ids, vector, metadata)
        ]
        if not rows:
            return

        self.client.upsert(collection_name=collection_name, data=rows)

    def delete(self, collection_name: str, _ids: List[str]):
        """Delete the records whose primary keys are in ``_ids``."""
        self.client.delete(collection_name=collection_name, ids=_ids)

    def write(self, *args, **kwargs):
        """No-op; accepts and ignores any arguments."""
        pass

View file

@ -79,3 +79,4 @@ gymnasium==0.29.1
boto3~=1.34.69
spark_ai_python~=0.3.30
agentops
pymilvus==2.4.5

View file

@ -0,0 +1,48 @@
import random
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore
# Seed the RNG so the generated fixture vectors are identical on every run.
seed_value = 42
random.seed(seed_value)

# Ten 8-dimensional random vectors, generated row by row.
vectors = [[random.random() for _component in range(8)] for _row in range(10)]

# Matching primary keys: "doc_0" .. "doc_9".
ids = ["doc_" + str(index) for index in range(10)]

# One metadata dict per document; rand_number is the document index modulo 10.
metadata = [dict(color="red", rand_number=index % 10) for index in range(10)]
def assert_almost_equal(actual, expected):
    """Assert that ``actual`` matches ``expected`` to within 1e-10.

    When ``expected`` is a list, ``actual`` must have the same length and the
    values are compared pairwise; otherwise the two values are compared
    directly. Raises ``AssertionError`` on any mismatch.
    """
    tolerance = 1e-10

    # Scalar case first, so the list handling below needs no extra nesting.
    if not isinstance(expected, list):
        assert abs(actual - expected) <= tolerance, f"{actual} is not within {tolerance} of {expected}"
        return

    assert len(actual) == len(expected)
    for got, want in zip(actual, expected):
        assert abs(got - want) <= tolerance, f"{got} is not within {tolerance} of {want}"
def test_milvus_store():
    """Exercise create, add, search (plain and filtered) and delete on MilvusStore.

    Uses the local database file ``./milvus_local.db`` and always drops the
    test collection afterwards, even when an assertion fails.
    """
    milvus_connection = MilvusConnection(uri="./milvus_local.db")
    milvus_store = MilvusStore(milvus_connection)

    collection_name = "TestCollection"
    milvus_store.create_collection(collection_name, dim=8)

    try:
        milvus_store.add(collection_name, ids, vectors, metadata)

        # NOTE(review): expecting "doc_0" as the top hit depends on the seeded
        # random vectors and the similarity metric - confirm this ordering.
        search_results = milvus_store.search(collection_name, query=[1.0] * 8)
        assert len(search_results) > 0
        first_result = search_results[0]
        assert first_result["id"] == "doc_0"

        # NOTE(review): presumably "rand_number" must be addressable as a
        # top-level field for this filter to match - verify against how the
        # store writes metadata.
        search_results_with_filter = milvus_store.search(
            collection_name,
            query=[1.0] * 8,
            filter={"rand_number": 1},
        )
        assert len(search_results_with_filter) > 0
        assert search_results_with_filter[0]["id"] == "doc_1"

        milvus_store.delete(collection_name, _ids=["doc_0"])
        deleted_results = milvus_store.search(collection_name, query=[1.0] * 8, limit=1)
        assert deleted_results[0]["id"] != "doc_0"
    finally:
        # Clean up regardless of the outcome so a failed run leaves no
        # collection behind in the local database.
        milvus_store.client.drop_collection(collection_name)