mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-25 00:36:55 +02:00
integrated milvus
This commit is contained in:
parent
ab846f65e4
commit
490203d20f
3 changed files with 173 additions and 0 deletions
124
metagpt/document_store/milvus_store.py
Normal file
124
metagpt/document_store/milvus_store.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pymilvus import MilvusClient, DataType
|
||||
|
||||
from metagpt.document_store.base_store import BaseStore
|
||||
|
||||
@dataclass
class MilvusConnection:
    """Connection settings handed to ``MilvusStore``.

    Args:
        uri: milvus url. ``MilvusStore`` refuses to connect when this is
            empty or ``None``.
        token: milvus token. Passed through to the client unchanged; may be
            left as ``None``.
    """

    # Both fields default to None, so the annotations are Optional rather
    # than plain ``str``.
    uri: Optional[str] = None
    token: Optional[str] = None
|
||||
|
||||
|
||||
class MilvusStore(BaseStore):
    """Vector store that talks to Milvus through a ``MilvusClient``.

    Collections created by :meth:`create_collection` declare two fields, a
    string primary key ``id`` and a float vector ``vector``. Rows written by
    :meth:`add` additionally carry a ``metadata`` key.
    """

    def __init__(self, connect: MilvusConnection):
        """Create the underlying client from ``connect``.

        Args:
            connect: connection settings; ``uri`` is mandatory.

        Raises:
            Exception: if ``connect.uri`` is empty or ``None``.
        """
        if not connect.uri:
            raise Exception("please check MilvusConnection, uri must be set.")
        self.client = MilvusClient(uri=connect.uri, token=connect.token)

    def create_collection(
        self,
        collection_name: str,
        dim: int,
        enable_dynamic_schema: bool = True,
    ):
        """Create ``collection_name`` from scratch with an indexed vector field.

        An existing collection with the same name is dropped first, so any
        data it holds is discarded.

        Args:
            collection_name: name of the collection to (re)create.
            dim: dimensionality of the ``vector`` field.
            enable_dynamic_schema: whether rows may carry keys beyond the
                declared ``id`` and ``vector`` fields.
        """
        if self.client.has_collection(collection_name=collection_name):
            self.client.drop_collection(collection_name=collection_name)

        schema = self.client.create_schema(
            auto_id=False,
            # Honour the caller's flag instead of hard-coding False: the
            # schema-level switch is what lets rows carry undeclared keys
            # such as the "metadata" key written by add().
            enable_dynamic_field=enable_dynamic_schema,
        )
        schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=36)
        schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dim)

        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            index_type="AUTOINDEX",
            metric_type="COSINE",
        )

        self.client.create_collection(
            collection_name=collection_name,
            schema=schema,
            index_params=index_params,
            enable_dynamic_schema=enable_dynamic_schema,
        )

    @staticmethod
    def build_filter(key, value) -> str:
        """Render one ``key``/``value`` pair as a filter expression.

        Strings become ``key == "value"``, lists become ``key in [...]``
        (using the list's Python repr), and anything else becomes
        ``key == value``.
        """
        if isinstance(value, str):
            return f'{key} == "{value}"'
        if isinstance(value, list):
            return f'{key} in {value}'
        return f'{key} == {value}'

    def search(
        self,
        collection_name: str,
        query: List[float],
        filter: Optional[Dict[str, Any]] = None,
        limit: int = 10,
        output_fields: Optional[List[str]] = None,
    ) -> List[dict]:
        """Run a vector search for a single query vector.

        Args:
            collection_name: collection to search.
            query: the query vector.
            filter: optional mapping of field name to required value (str,
                int or list); the conditions are combined with ``and``.
                ``None`` or an empty mapping means no filtering.
            limit: maximum number of hits to return.
            output_fields: extra fields to include in each hit.

        Returns:
            The hits for ``query``, i.e. the first (and only) entry of the
            client's per-query result list.
        """
        # Join instead of appending "... and " per item, which left a dangling
        # trailing operator; also tolerate the default filter=None.
        conditions = [self.build_filter(key, value) for key, value in (filter or {}).items()]
        filter_expression = " and ".join(conditions)

        hits_per_query = self.client.search(
            collection_name=collection_name,
            data=[query],
            filter=filter_expression,
            limit=limit,
            output_fields=output_fields,
        )
        return hits_per_query[0]

    def add(
        self,
        collection_name: str,
        _ids: List[str],
        vector: List[List[float]],
        metadata: List[Dict[str, Any]],
    ):
        """Upsert one row per id into ``collection_name``.

        Args:
            collection_name: target collection.
            _ids: primary keys, one per row.
            vector: embedding for each id, in the same order as ``_ids``.
            metadata: metadata dict for each id, in the same order as ``_ids``.

        Raises:
            ValueError: if the three lists do not have the same length.
        """
        if not (len(_ids) == len(vector) == len(metadata)):
            raise ValueError(
                f"_ids, vector and metadata must have the same length, "
                f"got {len(_ids)}, {len(vector)} and {len(metadata)}."
            )

        # One dict per row; a single shared dict would be overwritten on every
        # iteration and only the last row would reach Milvus.
        rows = [
            {"id": doc_id, "vector": embedding, "metadata": meta}
            for doc_id, embedding, meta in zip(_ids, vector, metadata)
        ]
        if not rows:
            return

        self.client.upsert(collection_name=collection_name, data=rows)

    def delete(self, collection_name: str, _ids: List[str]):
        """Delete the rows whose primary key is in ``_ids``."""
        self.client.delete(collection_name=collection_name, ids=_ids)

    def write(self, *args, **kwargs):
        """Do nothing; presumably present only to satisfy ``BaseStore`` (TODO confirm)."""
        pass
|
||||
|
|
@ -79,3 +79,4 @@ gymnasium==0.29.1
|
|||
boto3~=1.34.69
|
||||
spark_ai_python~=0.3.30
|
||||
agentops
|
||||
pymilvus==2.4.5
|
||||
|
|
|
|||
48
tests/metagpt/document_store/test_milvus_store.py
Normal file
48
tests/metagpt/document_store/test_milvus_store.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
import random
|
||||
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore
|
||||
|
||||
# Fixed seed so the generated vectors are identical on every run.
seed_value = 42
random.seed(seed_value)

# Ten 8-dimensional vectors, drawn row by row from the seeded generator.
vectors = []
for _row in range(10):
    vectors.append([random.random() for _col in range(8)])

# Matching primary keys and per-document metadata.
ids = ["doc_%d" % n for n in range(10)]
metadata = [dict(color="red", rand_number=n % 10) for n in range(10)]
|
||||
|
||||
|
||||
def assert_almost_equal(actual, expected):
    """Assert that ``actual`` matches ``expected`` to within 1e-10.

    When ``expected`` is a list, ``actual`` must have the same length and
    every pair of elements is compared; otherwise the two values are
    compared directly.
    """
    tolerance = 1e-10

    # Scalar case first, so the list handling below is not nested.
    if not isinstance(expected, list):
        assert abs(actual - expected) <= tolerance, f"{actual} is not within {tolerance} of {expected}"
        return

    assert len(actual) == len(expected)
    for pair in zip(actual, expected):
        ac, exp = pair
        assert abs(ac - exp) <= tolerance, f"{ac} is not within {tolerance} of {exp}"
|
||||
|
||||
|
||||
def test_milvus_store():
    """Exercise create/add/search/delete against a local Milvus database file.

    The collection is dropped in a ``finally`` block so that a failing
    assertion does not leave it behind for later runs.
    """
    milvus_connection = MilvusConnection(uri="./milvus_local.db")
    milvus_store = MilvusStore(milvus_connection)

    collection_name = "TestCollection"
    milvus_store.create_collection(collection_name, dim=8)

    try:
        milvus_store.add(collection_name, ids, vectors, metadata)

        # Unfiltered search.
        search_results = milvus_store.search(collection_name, query=[1.0] * 8)
        assert len(search_results) > 0
        first_result = search_results[0]
        assert first_result["id"] == "doc_0"

        # NOTE(review): this filters on a top-level "rand_number" key, while
        # add() is called with it nested inside each metadata dict - confirm
        # the key is queryable this way.
        search_results_with_filter = milvus_store.search(
            collection_name,
            query=[1.0] * 8,
            filter={"rand_number": 1},
        )
        assert len(search_results_with_filter) > 0
        assert search_results_with_filter[0]["id"] == "doc_1"

        # After deleting doc_0 it must no longer be the best hit.
        milvus_store.delete(collection_name, _ids=["doc_0"])
        deleted_results = milvus_store.search(collection_name, query=[1.0] * 8, limit=1)
        assert deleted_results[0]["id"] != "doc_0"
    finally:
        milvus_store.client.drop_collection(collection_name)
|
||||
Loading…
Add table
Add a link
Reference in a new issue