From dc38226b592e7970e44bdedd62d63633b02a8445 Mon Sep 17 00:00:00 2001 From: hezz Date: Sun, 13 Aug 2023 13:56:29 +0800 Subject: [PATCH 1/5] add qdrant store --- metagpt/document_store/base_store.py | 2 +- metagpt/document_store/qdrant_store.py | 90 +++++++ requirements.txt | 3 +- .../document_store/test_qdrant_store.py | 237 ++++++++++++++++++ 4 files changed, 330 insertions(+), 2 deletions(-) create mode 100644 metagpt/document_store/qdrant_store.py create mode 100644 tests/metagpt/document_store/test_qdrant_store.py diff --git a/metagpt/document_store/base_store.py b/metagpt/document_store/base_store.py index 01877e106..3dc96c0d6 100644 --- a/metagpt/document_store/base_store.py +++ b/metagpt/document_store/base_store.py @@ -15,7 +15,7 @@ class BaseStore(ABC): """FIXME: consider add_index, set_index and think 颗粒度""" @abstractmethod - def search(self, query, *args, **kwargs): + def search(self, *args, **kwargs): raise NotImplementedError @abstractmethod diff --git a/metagpt/document_store/qdrant_store.py b/metagpt/document_store/qdrant_store.py new file mode 100644 index 000000000..07791565e --- /dev/null +++ b/metagpt/document_store/qdrant_store.py @@ -0,0 +1,90 @@ +from dataclasses import dataclass +from typing import List + +from qdrant_client import QdrantClient +from qdrant_client.http.models import VectorParams +from qdrant_client.models import Filter, PointStruct + +from metagpt.document_store.base_store import BaseStore + + +@dataclass +class QdrantConnection: + url: str = None + host: str = None + port: int = None + memory: bool = False + api_key: str = None + + +class QdrantStore(BaseStore): + def __init__(self, connect: QdrantConnection): + if connect.memory: + self.client = QdrantClient(":memory:") + elif connect.url: + self.client = QdrantClient(url=connect.url, api_key=connect.api_key) + elif connect.host and connect.port: + self.client = QdrantClient( + host=connect.host, port=connect.port, api_key=connect.api_key + ) + else: + raise 
Exception("please check QdrantConnection.") + + def create_collection( + self, + collection_name: str, + vectors_config: VectorParams, + force_recreate=False, + **kwargs, + ): + try: + self.client.get_collection(collection_name) + if force_recreate: + res = self.client.recreate_collection( + collection_name, vectors_config=vectors_config, **kwargs + ) + return res + return True + except Exception: + return self.client.recreate_collection( + collection_name, vectors_config=vectors_config, **kwargs + ) + + def has_collection(self, collection_name: str): + try: + self.client.get_collection(collection_name) + return True + except Exception: + return False + + def delete_collection(self, collection_name: str, timeout=60): + res = self.client.delete_collection(collection_name, timeout=timeout) + if not res: + raise Exception(f"Delete collection {collection_name} failed.") + + def add(self, collection_name: str, points: List[PointStruct]): + # self.client.upload_records() + self.client.upsert( + collection_name, + points, + ) + + def search( + self, + collection_name: str, + query: List[float], + query_filter: Filter = None, + k=10, + return_vector=False, + ): + hits = self.client.search( + collection_name=collection_name, + query_vector=query, + query_filter=query_filter, + limit=k, + with_vectors=return_vector, + ) + return [hit.__dict__ for hit in hits] + + def write(self, *args, **kwargs): + pass diff --git a/requirements.txt b/requirements.txt index 452e2d092..c18145b98 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,7 +17,7 @@ numpy==1.24.3 openai==0.27.8 openpyxl pandas==1.4.1 -pydantic==1.10.7 +pydantic==1.10.8 #pygame==2.1.3 #pymilvus==2.2.8 pytest==7.2.2 @@ -36,3 +36,4 @@ anthropic==0.3.6 typing-inspect==0.8.0 typing_extensions==4.5.0 libcst==1.0.1 +qdrant-client==1.4.0 \ No newline at end of file diff --git a/tests/metagpt/document_store/test_qdrant_store.py b/tests/metagpt/document_store/test_qdrant_store.py new file mode 100644 index 
000000000..46f0ea376 --- /dev/null +++ b/tests/metagpt/document_store/test_qdrant_store.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/6/11 21:08 +@Author : hezhaozhao +@File : test_qdrant_store.py +""" +import random + +from qdrant_client.models import ( + Distance, + FieldCondition, + Filter, + PointStruct, + Range, + VectorParams, +) + +from metagpt.document_store.qdrant_store import QdrantConnection, QdrantStore + +seed_value = 42 +random.seed(seed_value) + +vectors = [[random.random() for _ in range(2)] for _ in range(10)] + +points = [ + PointStruct( + id=idx, vector=vector, payload={"color": "red", "rand_number": idx % 10} + ) + for idx, vector in enumerate(vectors) +] + + +def test_milvus_store(): + qdrant_connection = QdrantConnection(memory=True) + vectors_config = VectorParams(size=2, distance=Distance.COSINE) + qdrant_store = QdrantStore(qdrant_connection) + qdrant_store.create_collection("Book", vectors_config, force_recreate=True) + assert qdrant_store.has_collection("Book") is True + qdrant_store.delete_collection("Book") + assert qdrant_store.has_collection("Book") is False + qdrant_store.create_collection("Book", vectors_config) + assert qdrant_store.has_collection("Book") is True + qdrant_store.add("Book", points) + results = qdrant_store.search("Book", query=[1.0, 1.0]) + assert results == [ + { + "id": 2, + "version": 0, + "score": 0.999106722578389, + "payload": {"color": "red", "rand_number": 2}, + "vector": None, + }, + { + "id": 7, + "version": 0, + "score": 0.9961650411397226, + "payload": {"color": "red", "rand_number": 7}, + "vector": None, + }, + { + "id": 1, + "version": 0, + "score": 0.9946351526856256, + "payload": {"color": "red", "rand_number": 1}, + "vector": None, + }, + { + "id": 5, + "version": 0, + "score": 0.9297466022881021, + "payload": {"color": "red", "rand_number": 5}, + "vector": None, + }, + { + "id": 8, + "version": 0, + "score": 0.9100373450784073, + "payload": {"color": "red", 
"rand_number": 8}, + "vector": None, + }, + { + "id": 6, + "version": 0, + "score": 0.7944306996390111, + "payload": {"color": "red", "rand_number": 6}, + "vector": None, + }, + { + "id": 3, + "version": 0, + "score": 0.7723528053480722, + "payload": {"color": "red", "rand_number": 3}, + "vector": None, + }, + { + "id": 4, + "version": 0, + "score": 0.755163629383033, + "payload": {"color": "red", "rand_number": 4}, + "vector": None, + }, + { + "id": 0, + "version": 0, + "score": 0.73420337995255, + "payload": {"color": "red", "rand_number": 0}, + "vector": None, + }, + { + "id": 9, + "version": 0, + "score": 0.7127610621127889, + "payload": {"color": "red", "rand_number": 9}, + "vector": None, + }, + ] + results = qdrant_store.search("Book", query=[1.0, 1.0], return_vector=True) + assert results == [ + { + "id": 2, + "version": 0, + "score": 0.999106722578389, + "payload": {"color": "red", "rand_number": 2}, + "vector": [0.7363563179969788, 0.6765939593315125], + }, + { + "id": 7, + "version": 0, + "score": 0.9961650411397226, + "payload": {"color": "red", "rand_number": 7}, + "vector": [0.7662628889083862, 0.6425272226333618], + }, + { + "id": 1, + "version": 0, + "score": 0.9946351526856256, + "payload": {"color": "red", "rand_number": 1}, + "vector": [0.7764601111412048, 0.6301664113998413], + }, + { + "id": 5, + "version": 0, + "score": 0.9297466022881021, + "payload": {"color": "red", "rand_number": 5}, + "vector": [0.39707326889038086, 0.9177868962287903], + }, + { + "id": 8, + "version": 0, + "score": 0.9100373450784073, + "payload": {"color": "red", "rand_number": 8}, + "vector": [0.35037919878959656, 0.9366079568862915], + }, + { + "id": 6, + "version": 0, + "score": 0.7944306996390111, + "payload": {"color": "red", "rand_number": 6}, + "vector": [0.13228265941143036, 0.991212010383606], + }, + { + "id": 3, + "version": 0, + "score": 0.7723528053480722, + "payload": {"color": "red", "rand_number": 3}, + "vector": [0.9952857494354248, 0.0969860628247261], 
+ }, + { + "id": 4, + "version": 0, + "score": 0.755163629383033, + "payload": {"color": "red", "rand_number": 4}, + "vector": [0.9975154995918274, 0.07044714689254761], + }, + { + "id": 0, + "version": 0, + "score": 0.73420337995255, + "payload": {"color": "red", "rand_number": 0}, + "vector": [0.9992359280586243, 0.03908444941043854], + }, + { + "id": 9, + "version": 0, + "score": 0.7127610621127889, + "payload": {"color": "red", "rand_number": 9}, + "vector": [0.9999677538871765, 0.00802854634821415], + }, + ] + results = qdrant_store.search( + "Book", + query=[1.0, 1.0], + query_filter=Filter( + must=[FieldCondition(key="rand_number", range=Range(gte=8))] + ), + ) + assert results == [ + { + "id": 8, + "version": 0, + "score": 0.9100373450784073, + "payload": {"color": "red", "rand_number": 8}, + "vector": None, + }, + { + "id": 9, + "version": 0, + "score": 0.7127610621127889, + "payload": {"color": "red", "rand_number": 9}, + "vector": None, + }, + ] + results = qdrant_store.search( + "Book", + query=[1.0, 1.0], + query_filter=Filter( + must=[FieldCondition(key="rand_number", range=Range(gte=8))] + ), + return_vector=True, + ) + assert results == [ + { + "id": 8, + "version": 0, + "score": 0.9100373450784073, + "payload": {"color": "red", "rand_number": 8}, + "vector": [0.35037919878959656, 0.9366079568862915], + }, + { + "id": 9, + "version": 0, + "score": 0.7127610621127889, + "payload": {"color": "red", "rand_number": 9}, + "vector": [0.9999677538871765, 0.00802854634821415], + }, + ] From bb7ea6c39878f92412f421cd60736b0e8cabee75 Mon Sep 17 00:00:00 2001 From: hezz Date: Sun, 13 Aug 2023 21:19:43 +0800 Subject: [PATCH 2/5] change qdrant_store.py --- .../document_store/test_qdrant_store.py | 192 ++---------------- 1 file changed, 16 insertions(+), 176 deletions(-) diff --git a/tests/metagpt/document_store/test_qdrant_store.py b/tests/metagpt/document_store/test_qdrant_store.py index 46f0ea376..a63a4329d 100644 --- 
a/tests/metagpt/document_store/test_qdrant_store.py +++ b/tests/metagpt/document_store/test_qdrant_store.py @@ -43,151 +43,17 @@ def test_milvus_store(): assert qdrant_store.has_collection("Book") is True qdrant_store.add("Book", points) results = qdrant_store.search("Book", query=[1.0, 1.0]) - assert results == [ - { - "id": 2, - "version": 0, - "score": 0.999106722578389, - "payload": {"color": "red", "rand_number": 2}, - "vector": None, - }, - { - "id": 7, - "version": 0, - "score": 0.9961650411397226, - "payload": {"color": "red", "rand_number": 7}, - "vector": None, - }, - { - "id": 1, - "version": 0, - "score": 0.9946351526856256, - "payload": {"color": "red", "rand_number": 1}, - "vector": None, - }, - { - "id": 5, - "version": 0, - "score": 0.9297466022881021, - "payload": {"color": "red", "rand_number": 5}, - "vector": None, - }, - { - "id": 8, - "version": 0, - "score": 0.9100373450784073, - "payload": {"color": "red", "rand_number": 8}, - "vector": None, - }, - { - "id": 6, - "version": 0, - "score": 0.7944306996390111, - "payload": {"color": "red", "rand_number": 6}, - "vector": None, - }, - { - "id": 3, - "version": 0, - "score": 0.7723528053480722, - "payload": {"color": "red", "rand_number": 3}, - "vector": None, - }, - { - "id": 4, - "version": 0, - "score": 0.755163629383033, - "payload": {"color": "red", "rand_number": 4}, - "vector": None, - }, - { - "id": 0, - "version": 0, - "score": 0.73420337995255, - "payload": {"color": "red", "rand_number": 0}, - "vector": None, - }, - { - "id": 9, - "version": 0, - "score": 0.7127610621127889, - "payload": {"color": "red", "rand_number": 9}, - "vector": None, - }, - ] + assert results[0]["id"] == 2 + assert results[0]["score"] == 0.999106722578389 + assert results[1]["id"] == 7 + assert results[1]["score"] == 0.9961650411397226 results = qdrant_store.search("Book", query=[1.0, 1.0], return_vector=True) - assert results == [ - { - "id": 2, - "version": 0, - "score": 0.999106722578389, - "payload": 
{"color": "red", "rand_number": 2}, - "vector": [0.7363563179969788, 0.6765939593315125], - }, - { - "id": 7, - "version": 0, - "score": 0.9961650411397226, - "payload": {"color": "red", "rand_number": 7}, - "vector": [0.7662628889083862, 0.6425272226333618], - }, - { - "id": 1, - "version": 0, - "score": 0.9946351526856256, - "payload": {"color": "red", "rand_number": 1}, - "vector": [0.7764601111412048, 0.6301664113998413], - }, - { - "id": 5, - "version": 0, - "score": 0.9297466022881021, - "payload": {"color": "red", "rand_number": 5}, - "vector": [0.39707326889038086, 0.9177868962287903], - }, - { - "id": 8, - "version": 0, - "score": 0.9100373450784073, - "payload": {"color": "red", "rand_number": 8}, - "vector": [0.35037919878959656, 0.9366079568862915], - }, - { - "id": 6, - "version": 0, - "score": 0.7944306996390111, - "payload": {"color": "red", "rand_number": 6}, - "vector": [0.13228265941143036, 0.991212010383606], - }, - { - "id": 3, - "version": 0, - "score": 0.7723528053480722, - "payload": {"color": "red", "rand_number": 3}, - "vector": [0.9952857494354248, 0.0969860628247261], - }, - { - "id": 4, - "version": 0, - "score": 0.755163629383033, - "payload": {"color": "red", "rand_number": 4}, - "vector": [0.9975154995918274, 0.07044714689254761], - }, - { - "id": 0, - "version": 0, - "score": 0.73420337995255, - "payload": {"color": "red", "rand_number": 0}, - "vector": [0.9992359280586243, 0.03908444941043854], - }, - { - "id": 9, - "version": 0, - "score": 0.7127610621127889, - "payload": {"color": "red", "rand_number": 9}, - "vector": [0.9999677538871765, 0.00802854634821415], - }, - ] + assert results[0]["id"] == 2 + assert results[0]["score"] == 0.999106722578389 + assert results[0]["vector"] == [0.7363563179969788, 0.6765939593315125] + assert results[1]["id"] == 7 + assert results[1]["score"] == 0.9961650411397226 + assert results[1]["vector"] == [0.7662628889083862, 0.6425272226333618] results = qdrant_store.search( "Book", query=[1.0, 
1.0], @@ -195,22 +61,10 @@ def test_milvus_store(): must=[FieldCondition(key="rand_number", range=Range(gte=8))] ), ) - assert results == [ - { - "id": 8, - "version": 0, - "score": 0.9100373450784073, - "payload": {"color": "red", "rand_number": 8}, - "vector": None, - }, - { - "id": 9, - "version": 0, - "score": 0.7127610621127889, - "payload": {"color": "red", "rand_number": 9}, - "vector": None, - }, - ] + assert results[0]["id"] == 8 + assert results[0]["score"] == 0.9100373450784073 + assert results[1]["id"] == 9 + assert results[1]["score"] == 0.7127610621127889 results = qdrant_store.search( "Book", query=[1.0, 1.0], @@ -219,19 +73,5 @@ def test_milvus_store(): ), return_vector=True, ) - assert results == [ - { - "id": 8, - "version": 0, - "score": 0.9100373450784073, - "payload": {"color": "red", "rand_number": 8}, - "vector": [0.35037919878959656, 0.9366079568862915], - }, - { - "id": 9, - "version": 0, - "score": 0.7127610621127889, - "payload": {"color": "red", "rand_number": 9}, - "vector": [0.9999677538871765, 0.00802854634821415], - }, - ] + assert results[0]["vector"] == [0.35037919878959656, 0.9366079568862915] + assert results[1]["vector"] == [0.9999677538871765, 0.00802854634821415] From 80506ec3cebafb815b3056b0cfeea49ab3dbec7c Mon Sep 17 00:00:00 2001 From: root Date: Mon, 14 Aug 2023 21:59:48 +0800 Subject: [PATCH 3/5] add param annotations for qdrant --- metagpt/document_store/qdrant_store.py | 40 ++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/metagpt/document_store/qdrant_store.py b/metagpt/document_store/qdrant_store.py index 07791565e..9968794ef 100644 --- a/metagpt/document_store/qdrant_store.py +++ b/metagpt/document_store/qdrant_store.py @@ -10,6 +10,14 @@ from metagpt.document_store.base_store import BaseStore @dataclass class QdrantConnection: + """ + Args: + url: qdrant url + host: qdrant host + port: qdrant port + memory: qdrant service use memory mode + api_key: qdrant cloud api_key + """ url: str = None 
host: str = None port: int = None @@ -37,6 +45,17 @@ class QdrantStore(BaseStore): force_recreate=False, **kwargs, ): + """ + create a collection + Args: + collection_name: collection name + vectors_config: VectorParams object,detail in https://github.com/qdrant/qdrant-client + force_recreate: default is False, if True, will delete exists collection,then create it + **kwargs: + + Returns: + + """ try: self.client.get_collection(collection_name) if force_recreate: @@ -63,6 +82,15 @@ class QdrantStore(BaseStore): raise Exception(f"Delete collection {collection_name} failed.") def add(self, collection_name: str, points: List[PointStruct]): + """ + add some vector data to qdrant + Args: + collection_name: collection name + points: list of PointStruct object, about PointStruct detail in https://github.com/qdrant/qdrant-client + + Returns: None + + """ # self.client.upload_records() self.client.upsert( collection_name, @@ -77,6 +105,18 @@ class QdrantStore(BaseStore): k=10, return_vector=False, ): + """ + vector search + Args: + collection_name: qdrant collection name + query: input vector + query_filter: Filter object, detail in https://github.com/qdrant/qdrant-client + k: return the most similar k pieces of data + return_vector: whether return vector + + Returns: list of dict + + """ hits = self.client.search( collection_name=collection_name, query_vector=query, From 20adbdcf76d8b1dbdb7f155249f7885fb20de02e Mon Sep 17 00:00:00 2001 From: root Date: Mon, 14 Aug 2023 22:00:34 +0800 Subject: [PATCH 4/5] add param annotations for qdrant --- metagpt/document_store/qdrant_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metagpt/document_store/qdrant_store.py b/metagpt/document_store/qdrant_store.py index 9968794ef..87779a561 100644 --- a/metagpt/document_store/qdrant_store.py +++ b/metagpt/document_store/qdrant_store.py @@ -88,7 +88,7 @@ class QdrantStore(BaseStore): collection_name: collection name points: list of PointStruct object, about 
PointStruct detail in https://github.com/qdrant/qdrant-client - Returns: None + Returns: NoneX """ # self.client.upload_records() From 5c5100761078b23e44c0f21b42999c641231b297 Mon Sep 17 00:00:00 2001 From: hezhaozhao Date: Mon, 14 Aug 2023 22:20:46 +0800 Subject: [PATCH 5/5] optimize qdrant models import --- metagpt/document_store/qdrant_store.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/metagpt/document_store/qdrant_store.py b/metagpt/document_store/qdrant_store.py index 87779a561..98b82cf87 100644 --- a/metagpt/document_store/qdrant_store.py +++ b/metagpt/document_store/qdrant_store.py @@ -2,8 +2,7 @@ from dataclasses import dataclass from typing import List from qdrant_client import QdrantClient -from qdrant_client.http.models import VectorParams -from qdrant_client.models import Filter, PointStruct +from qdrant_client.models import Filter, PointStruct, VectorParams from metagpt.document_store.base_store import BaseStore