diff --git a/metagpt/document_store/base_store.py b/metagpt/document_store/base_store.py
index 01877e106..3dc96c0d6 100644
--- a/metagpt/document_store/base_store.py
+++ b/metagpt/document_store/base_store.py
@@ -15,7 +15,7 @@ class BaseStore(ABC):
     """FIXME: consider add_index, set_index and think 颗粒度"""
 
     @abstractmethod
-    def search(self, query, *args, **kwargs):
+    def search(self, *args, **kwargs):
         raise NotImplementedError
 
     @abstractmethod
diff --git a/metagpt/document_store/qdrant_store.py b/metagpt/document_store/qdrant_store.py
new file mode 100644
index 000000000..98b82cf87
--- /dev/null
+++ b/metagpt/document_store/qdrant_store.py
@@ -0,0 +1,155 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
+from qdrant_client import QdrantClient
+from qdrant_client.models import Filter, PointStruct, VectorParams
+
+from metagpt.document_store.base_store import BaseStore
+
+
+@dataclass
+class QdrantConnection:
+    """
+    Connection settings for a qdrant service.
+
+    Args:
+        url: qdrant url
+        host: qdrant host
+        port: qdrant port
+        memory: use the in-memory mode instead of a remote service
+        api_key: qdrant cloud api_key
+    """
+
+    url: Optional[str] = None
+    host: Optional[str] = None
+    port: Optional[int] = None
+    memory: bool = False
+    api_key: Optional[str] = None
+
+
+class QdrantStore(BaseStore):
+    """Document store that keeps vectors in qdrant collections."""
+
+    def __init__(self, connect: QdrantConnection):
+        """
+        Build the qdrant client from the connection settings.
+
+        The settings are checked in order: memory, url, then host + port.
+
+        Raises:
+            ValueError: if none of memory, url or host + port is set
+        """
+        if connect.memory:
+            self.client = QdrantClient(":memory:")
+        elif connect.url:
+            self.client = QdrantClient(url=connect.url, api_key=connect.api_key)
+        elif connect.host and connect.port:
+            self.client = QdrantClient(
+                host=connect.host, port=connect.port, api_key=connect.api_key
+            )
+        else:
+            raise ValueError("please check QdrantConnection.")
+
+    def create_collection(
+        self,
+        collection_name: str,
+        vectors_config: VectorParams,
+        force_recreate=False,
+        **kwargs,
+    ):
+        """
+        Create a collection.
+
+        Args:
+            collection_name: collection name
+            vectors_config: VectorParams object, details in https://github.com/qdrant/qdrant-client
+            force_recreate: default is False; if True, an existing collection is deleted and created again
+            **kwargs: extra keyword arguments forwarded to QdrantClient.recreate_collection
+
+        Returns: True if the collection already exists and is kept, otherwise the result of
+            QdrantClient.recreate_collection
+        """
+        # Only an existing collection that must be kept skips the (re)creation.
+        if not force_recreate and self.has_collection(collection_name):
+            return True
+        return self.client.recreate_collection(
+            collection_name, vectors_config=vectors_config, **kwargs
+        )
+
+    def has_collection(self, collection_name: str):
+        """
+        Check whether a collection exists.
+
+        Args:
+            collection_name: collection name
+
+        Returns: True if the collection can be fetched, False if fetching it fails
+        """
+        try:
+            self.client.get_collection(collection_name)
+        except Exception:  # any lookup failure is reported as "missing"
+            return False
+        return True
+
+    def delete_collection(self, collection_name: str, timeout=60):
+        """
+        Delete a collection.
+
+        Args:
+            collection_name: collection name
+            timeout: timeout passed to QdrantClient.delete_collection
+
+        Raises:
+            Exception: if the client reports that the deletion failed
+        """
+        res = self.client.delete_collection(collection_name, timeout=timeout)
+        if not res:
+            raise Exception(f"Delete collection {collection_name} failed.")
+
+    def add(self, collection_name: str, points: List[PointStruct]):
+        """
+        Add some vector data to qdrant.
+
+        Args:
+            collection_name: collection name
+            points: list of PointStruct objects, details in https://github.com/qdrant/qdrant-client
+
+        Returns: None
+        """
+        self.client.upsert(
+            collection_name,
+            points,
+        )
+
+    def search(
+        self,
+        collection_name: str,
+        query: List[float],
+        query_filter: Optional[Filter] = None,
+        k=10,
+        return_vector=False,
+    ):
+        """
+        Vector search.
+
+        Args:
+            collection_name: qdrant collection name
+            query: input vector
+            query_filter: Filter object, details in https://github.com/qdrant/qdrant-client
+            k: return the most similar k pieces of data
+            return_vector: whether to return the vectors of the hits
+
+        Returns: list of dict, one per hit
+        """
+        hits = self.client.search(
+            collection_name=collection_name,
+            query_vector=query,
+            query_filter=query_filter,
+            limit=k,
+            with_vectors=return_vector,
+        )
+        return [hit.__dict__ for hit in hits]
+
+    def write(self, *args, **kwargs):
+        """No-op placeholder; writing is not implemented for this store."""
+        pass
diff --git a/requirements.txt b/requirements.txt
index 452e2d092..c18145b98 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,7 +17,7 @@ numpy==1.24.3
 openai==0.27.8
 openpyxl
 pandas==1.4.1
-pydantic==1.10.7
+pydantic==1.10.8
 #pygame==2.1.3
 #pymilvus==2.2.8
 pytest==7.2.2
@@ -36,3 +36,4 @@ anthropic==0.3.6
 typing-inspect==0.8.0
 typing_extensions==4.5.0
 libcst==1.0.1
+qdrant-client==1.4.0
\ No newline at end of file
diff --git a/tests/metagpt/document_store/test_qdrant_store.py b/tests/metagpt/document_store/test_qdrant_store.py
new file mode 100644
index 000000000..a63a4329d
--- /dev/null
+++ b/tests/metagpt/document_store/test_qdrant_store.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2023/6/11 21:08
+@Author : hezhaozhao
+@File : test_qdrant_store.py
+"""
+import random
+
+from qdrant_client.models import (
+    Distance,
+    FieldCondition,
+    Filter,
+    PointStruct,
+    Range,
+    VectorParams,
+)
+
+from metagpt.document_store.qdrant_store import QdrantConnection, QdrantStore
+
+seed_value = 42
+random.seed(seed_value)
+
+vectors = [[random.random() for _ in range(2)] for _ in range(10)]
+
+points = [
+    PointStruct(
+        id=idx, vector=vector, payload={"color": "red", "rand_number": idx % 10}
+    )
+    for idx, vector in enumerate(vectors)
+]
+
+
+def test_qdrant_store():
+    """Exercise collection management, insertion and filtered search in memory mode."""
+    qdrant_connection = QdrantConnection(memory=True)
+    vectors_config = VectorParams(size=2, distance=Distance.COSINE)
+    qdrant_store = QdrantStore(qdrant_connection)
+    qdrant_store.create_collection("Book", vectors_config, force_recreate=True)
+    assert qdrant_store.has_collection("Book") is True
+    qdrant_store.delete_collection("Book")
+    assert qdrant_store.has_collection("Book") is False
+    qdrant_store.create_collection("Book", vectors_config)
+    assert qdrant_store.has_collection("Book") is True
+    qdrant_store.add("Book", points)
+    results = qdrant_store.search("Book", query=[1.0, 1.0])
+    assert results[0]["id"] == 2
+    assert results[0]["score"] == 0.999106722578389
+    assert results[1]["id"] == 7
+    assert results[1]["score"] == 0.9961650411397226
+    results = qdrant_store.search("Book", query=[1.0, 1.0], return_vector=True)
+    assert results[0]["id"] == 2
+    assert results[0]["score"] == 0.999106722578389
+    assert results[0]["vector"] == [0.7363563179969788, 0.6765939593315125]
+    assert results[1]["id"] == 7
+    assert results[1]["score"] == 0.9961650411397226
+    assert results[1]["vector"] == [0.7662628889083862, 0.6425272226333618]
+    results = qdrant_store.search(
+        "Book",
+        query=[1.0, 1.0],
+        query_filter=Filter(
+            must=[FieldCondition(key="rand_number", range=Range(gte=8))]
+        ),
+    )
+    assert results[0]["id"] == 8
+    assert results[0]["score"] == 0.9100373450784073
+    assert results[1]["id"] == 9
+    assert results[1]["score"] == 0.7127610621127889
+    results = qdrant_store.search(
+        "Book",
+        query=[1.0, 1.0],
+        query_filter=Filter(
+            must=[FieldCondition(key="rand_number", range=Range(gte=8))]
+        ),
+        return_vector=True,
+    )
+    assert results[0]["vector"] == [0.35037919878959656, 0.9366079568862915]
+    assert results[1]["vector"] == [0.9999677538871765, 0.00802854634821415]