"""Qdrant-backed implementation of the MetaGPT document store interface."""
from dataclasses import dataclass
from typing import List, Optional

from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import VectorParams
from qdrant_client.models import Filter, PointStruct

from metagpt.document_store.base_store import BaseStore


@dataclass
class QdrantConnection:
    """Connection settings consumed by ``QdrantStore``.

    ``QdrantStore`` inspects the fields in this order and uses the first
    match: ``memory``, then ``url``, then ``host`` together with ``port``.
    ``api_key`` is forwarded for the ``url`` and ``host``/``port`` modes.
    """

    url: Optional[str] = None
    host: Optional[str] = None
    port: Optional[int] = None
    memory: bool = False
    api_key: Optional[str] = None


class QdrantStore(BaseStore):
    """Thin wrapper around ``QdrantClient`` exposing collection and search helpers."""

    def __init__(self, connect: QdrantConnection):
        """Create the underlying client from ``connect``.

        Raises:
            ValueError: if ``connect`` selects none of the supported modes
                (in-memory, URL, or host + port).
        """
        if connect.memory:
            self.client = QdrantClient(":memory:")
        elif connect.url:
            self.client = QdrantClient(url=connect.url, api_key=connect.api_key)
        elif connect.host and connect.port:
            self.client = QdrantClient(
                host=connect.host, port=connect.port, api_key=connect.api_key
            )
        else:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError("please check QdrantConnection.")

    def create_collection(
        self,
        collection_name: str,
        vectors_config: VectorParams,
        force_recreate=False,
        **kwargs,
    ):
        """Ensure ``collection_name`` exists, optionally rebuilding it.

        Args:
            collection_name: name of the collection.
            vectors_config: vector parameters passed to the client.
            force_recreate: when true, an existing collection is recreated.
            **kwargs: forwarded to ``QdrantClient.recreate_collection``.

        Returns:
            ``True`` when the collection already exists and is left untouched,
            otherwise the result of ``QdrantClient.recreate_collection``.
        """
        # Only a confirmed "not found" (or an explicit force) may reach
        # recreate_collection; any other lookup failure propagates from
        # has_collection instead of silently wiping existing data.
        if self.has_collection(collection_name) and not force_recreate:
            return True
        return self.client.recreate_collection(
            collection_name, vectors_config=vectors_config, **kwargs
        )

    def has_collection(self, collection_name: str):
        """Return ``True`` if the collection exists, ``False`` if it does not.

        Raises:
            UnexpectedResponse: for any HTTP failure other than 404.
        """
        try:
            self.client.get_collection(collection_name)
        except ValueError:
            # NOTE(review): the local (":memory:") backend is expected to
            # signal a missing collection with ValueError - confirm when
            # upgrading qdrant-client.
            return False
        except UnexpectedResponse as err:
            # NOTE(review): the REST backend is expected to answer 404 for a
            # missing collection; everything else is a genuine failure.
            if err.status_code == 404:
                return False
            raise
        return True

    def delete_collection(self, collection_name: str, timeout=60):
        """Delete ``collection_name``.

        Raises:
            RuntimeError: if the client reports that the deletion failed.
        """
        res = self.client.delete_collection(collection_name, timeout=timeout)
        if not res:
            raise RuntimeError(f"Delete collection {collection_name} failed.")

    def add(self, collection_name: str, points: List[PointStruct]):
        """Upsert ``points`` into ``collection_name``."""
        self.client.upsert(
            collection_name,
            points,
        )

    def search(
        self,
        collection_name: str,
        query: List[float],
        query_filter: Optional[Filter] = None,
        k=10,
        return_vector=False,
    ):
        """Run a vector search and return the hits as plain dictionaries.

        Args:
            collection_name: collection to search.
            query: query vector.
            query_filter: optional filter applied by the client.
            k: maximum number of hits.
            return_vector: whether the hits should carry their vectors.

        Returns:
            One dict per hit, built from the hit object's attributes.
        """
        hits = self.client.search(
            collection_name=collection_name,
            query_vector=query,
            query_filter=query_filter,
            limit=k,
            with_vectors=return_vector,
        )
        # Copy so callers cannot mutate the client's result objects in place.
        return [dict(vars(hit)) for hit in hits]

    def write(self, *args, **kwargs):
        """Intentionally does nothing; accepts and ignores all arguments."""
        pass
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2023/6/11 21:08
@Author  : hezhaozhao
@File    : test_qdrant_store.py
"""
import math
import random

from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    PointStruct,
    Range,
    VectorParams,
)

from metagpt.document_store.qdrant_store import QdrantConnection, QdrantStore

seed_value = 42
random.seed(seed_value)

vectors = [[random.random() for _ in range(2)] for _ in range(10)]

points = [
    PointStruct(
        id=idx, vector=vector, payload={"color": "red", "rand_number": idx % 10}
    )
    for idx, vector in enumerate(vectors)
]

# Expected hits for the query [1.0, 1.0], best match first:
# (point id, similarity score, vector as returned by the store).
EXPECTED_HITS = [
    (2, 0.999106722578389, [0.7363563179969788, 0.6765939593315125]),
    (7, 0.9961650411397226, [0.7662628889083862, 0.6425272226333618]),
    (1, 0.9946351526856256, [0.7764601111412048, 0.6301664113998413]),
    (5, 0.9297466022881021, [0.39707326889038086, 0.9177868962287903]),
    (8, 0.9100373450784073, [0.35037919878959656, 0.9366079568862915]),
    (6, 0.7944306996390111, [0.13228265941143036, 0.991212010383606]),
    (3, 0.7723528053480722, [0.9952857494354248, 0.0969860628247261]),
    (4, 0.755163629383033, [0.9975154995918274, 0.07044714689254761]),
    (0, 0.73420337995255, [0.9992359280586243, 0.03908444941043854]),
    (9, 0.7127610621127889, [0.9999677538871765, 0.00802854634821415]),
]

# Same ranking restricted to points whose payload has rand_number >= 8.
EXPECTED_FILTERED_HITS = [hit for hit in EXPECTED_HITS if hit[0] >= 8]

HIT_KEYS = {"id", "version", "score", "payload", "vector"}
REL_TOL = 1e-6


def _assert_hits(results, expected, with_vectors):
    """Check ``results`` against ``expected`` using tolerant float comparison."""
    assert len(results) == len(expected)
    for hit, (point_id, score, vector) in zip(results, expected):
        assert set(hit) == HIT_KEYS
        assert hit["id"] == point_id
        assert hit["version"] == 0
        assert hit["payload"] == {"color": "red", "rand_number": point_id}
        # Scores and vectors are floats: exact equality is platform-sensitive.
        assert math.isclose(hit["score"], score, rel_tol=REL_TOL)
        if not with_vectors:
            assert hit["vector"] is None
            continue
        assert len(hit["vector"]) == len(vector)
        for got, want in zip(hit["vector"], vector):
            assert math.isclose(got, want, rel_tol=REL_TOL)


def test_milvus_store():
    """Exercise QdrantStore end to end against an in-memory Qdrant instance."""
    qdrant_connection = QdrantConnection(memory=True)
    vectors_config = VectorParams(size=2, distance=Distance.COSINE)
    qdrant_store = QdrantStore(qdrant_connection)

    # Collection lifecycle: create (forced), delete, create again.
    qdrant_store.create_collection("Book", vectors_config, force_recreate=True)
    assert qdrant_store.has_collection("Book") is True
    qdrant_store.delete_collection("Book")
    assert qdrant_store.has_collection("Book") is False
    qdrant_store.create_collection("Book", vectors_config)
    assert qdrant_store.has_collection("Book") is True

    qdrant_store.add("Book", points)

    # Unfiltered search, without and with vectors.
    results = qdrant_store.search("Book", query=[1.0, 1.0])
    _assert_hits(results, EXPECTED_HITS, with_vectors=False)

    results = qdrant_store.search("Book", query=[1.0, 1.0], return_vector=True)
    _assert_hits(results, EXPECTED_HITS, with_vectors=True)

    # Filtered search, without and with vectors.
    results = qdrant_store.search(
        "Book",
        query=[1.0, 1.0],
        query_filter=Filter(
            must=[FieldCondition(key="rand_number", range=Range(gte=8))]
        ),
    )
    _assert_hits(results, EXPECTED_FILTERED_HITS, with_vectors=False)

    results = qdrant_store.search(
        "Book",
        query=[1.0, 1.0],
        query_filter=Filter(
            must=[FieldCondition(key="rand_number", range=Range(gte=8))]
        ),
        return_vector=True,
    )
    _assert_hits(results, EXPECTED_FILTERED_HITS, with_vectors=True)